diff options
author | Kali Kaneko <kali@leap.se> | 2013-07-02 22:29:32 +0900 |
---|---|---|
committer | Kali Kaneko <kali@leap.se> | 2013-07-02 22:29:32 +0900 |
commit | 1d30e2580592ef905d9b21c475459da1e40b1cd6 (patch) | |
tree | 6ec48c6b5234da55ecc91ad3c6235fb20b61315c /src/leap/services/eip | |
parent | 81dc8ebe9ef46c0fafa75cba5c4959bb822da686 (diff) | |
parent | 5b975799ce9b7a6e0a88be4bcb48bdfb90800bb3 (diff) |
Merge branch 'master' of ssh://leap.se/leap_client
Diffstat (limited to 'src/leap/services/eip')
-rw-r--r-- | src/leap/services/eip/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/services/eip/eipbootstrapper.py | 178 | ||||
-rw-r--r-- | src/leap/services/eip/eipconfig.py | 260 | ||||
-rw-r--r-- | src/leap/services/eip/eipspec.py | 65 | ||||
-rw-r--r-- | src/leap/services/eip/providerbootstrapper.py | 311 | ||||
-rw-r--r-- | src/leap/services/eip/tests/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/services/eip/tests/test_eipbootstrapper.py | 347 | ||||
-rw-r--r-- | src/leap/services/eip/tests/test_eipconfig.py | 313 | ||||
-rw-r--r-- | src/leap/services/eip/tests/test_providerbootstrapper.py | 504 | ||||
-rw-r--r-- | src/leap/services/eip/tests/test_vpngatewayselector.py | 131 | ||||
-rw-r--r-- | src/leap/services/eip/tests/wrongcert.pem | 33 | ||||
-rw-r--r-- | src/leap/services/eip/udstelnet.py | 60 | ||||
-rw-r--r-- | src/leap/services/eip/vpnlaunchers.py | 774 | ||||
-rw-r--r-- | src/leap/services/eip/vpnprocess.py | 718 |
14 files changed, 3694 insertions, 0 deletions
diff --git a/src/leap/services/eip/__init__.py b/src/leap/services/eip/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/services/eip/__init__.py diff --git a/src/leap/services/eip/eipbootstrapper.py b/src/leap/services/eip/eipbootstrapper.py new file mode 100644 index 00000000..b2af0aea --- /dev/null +++ b/src/leap/services/eip/eipbootstrapper.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# eipbootstrapper.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +""" +EIP bootstrapping +""" + +import logging +import os + +from PySide import QtCore + +from leap.common.check import leap_assert, leap_assert_type +from leap.common import certs +from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p +from leap.config.providerconfig import ProviderConfig +from leap.crypto.srpauth import SRPAuth +from leap.services.eip.eipconfig import EIPConfig +from leap.util.request_helpers import get_content +from leap.services.abstractbootstrapper import AbstractBootstrapper + +logger = logging.getLogger(__name__) + + +class EIPBootstrapper(AbstractBootstrapper): + """ + Sets up EIP for a provider a series of checks and emits signals + after they are passed. + If a check fails, the subsequent checks are not executed + """ + + # All dicts returned are of the form + # {"passed": bool, "error": str} + download_config = QtCore.Signal(dict) + download_client_certificate = QtCore.Signal(dict) + + def __init__(self): + AbstractBootstrapper.__init__(self) + + self._provider_config = None + self._eip_config = None + self._download_if_needed = False + + def _download_config(self, *args): + """ + Downloads the EIP config for the given provider + """ + + leap_assert(self._provider_config, + "We need a provider configuration!") + + logger.debug("Downloading EIP config for %s" % + (self._provider_config.get_domain(),)) + + self._eip_config = EIPConfig() + + headers = {} + mtime = get_mtime(os.path.join(self._eip_config + .get_path_prefix(), + "leap", + "providers", + self._provider_config.get_domain(), + "eip-service.json")) + + if self._download_if_needed and mtime: + headers['if-modified-since'] = mtime + + # there is some confusion with this uri, + # it's in 1/config/eip, config/eip and config/1/eip... + config_uri = "%s/%s/config/eip-service.json" % ( + self._provider_config.get_api_uri(), + self._provider_config.get_api_version()) + logger.debug('Downloading eip config from: %s' % config_uri) + + res = self._session.get(config_uri, + verify=self._provider_config + .get_ca_cert_path(), + headers=headers) + res.raise_for_status() + + # Not modified + if res.status_code == 304: + logger.debug("EIP definition has not been modified") + else: + eip_definition, mtime = get_content(res) + + self._eip_config.load(data=eip_definition, mtime=mtime) + self._eip_config.save(["leap", + "providers", + self._provider_config.get_domain(), + "eip-service.json"]) + + def _download_client_certificates(self, *args): + """ + Downloads the EIP client certificate for the given provider + """ + leap_assert(self._provider_config, "We need a provider configuration!") + leap_assert(self._eip_config, "We need an eip configuration!") + + logger.debug("Downloading EIP client certificate for %s" % + (self._provider_config.get_domain(),)) + + client_cert_path = self._eip_config.\ + get_client_cert_path(self._provider_config, + about_to_download=True) + + # For re-download if something is wrong with the cert + self._download_if_needed = self._download_if_needed and \ + not certs.should_redownload(client_cert_path) + + if self._download_if_needed and \ + os.path.exists(client_cert_path): + check_and_fix_urw_only(client_cert_path) + return + + srp_auth = SRPAuth(self._provider_config) + session_id = srp_auth.get_session_id() + cookies = None + if session_id: + cookies = {"_session_id": session_id} + cert_uri = "%s/%s/cert" % ( + self._provider_config.get_api_uri(), + self._provider_config.get_api_version()) + logger.debug('getting cert from uri: %s' % cert_uri) + res = self._session.get(cert_uri, + verify=self._provider_config + .get_ca_cert_path(), + cookies=cookies) + res.raise_for_status() + client_cert = res.content + + if not certs.is_valid_pemfile(client_cert): + raise Exception(self.tr("The downloaded certificate is not a " + "valid PEM file")) + + mkdir_p(os.path.dirname(client_cert_path)) + + with open(client_cert_path, "w") as f: + f.write(client_cert) + + check_and_fix_urw_only(client_cert_path) + + def run_eip_setup_checks(self, + provider_config, + download_if_needed=False): + """ + Starts the checks needed for a new eip setup + + :param provider_config: Provider configuration + :type provider_config: ProviderConfig + """ + leap_assert(provider_config, "We need a provider config!") + leap_assert_type(provider_config, ProviderConfig) + + self._provider_config = provider_config + self._download_if_needed = download_if_needed + + cb_chain = [ + (self._download_config, self.download_config), + (self._download_client_certificates, + self.download_client_certificate) + ] + + return self.addCallbackChain(cb_chain) diff --git a/src/leap/services/eip/eipconfig.py b/src/leap/services/eip/eipconfig.py new file mode 100644 index 00000000..9e3a9b29 --- /dev/null +++ b/src/leap/services/eip/eipconfig.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- +# eipconfig.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +""" +Provider configuration +""" +import logging +import os +import re +import time + +import ipaddr + +from leap.common.check import leap_assert, leap_assert_type +from leap.common.config.baseconfig import BaseConfig +from leap.config.providerconfig import ProviderConfig +from leap.services.eip.eipspec import eipservice_config_spec + +logger = logging.getLogger(__name__) + + +class VPNGatewaySelector(object): + """ + VPN Gateway selector. + """ + # http://www.timeanddate.com/time/map/ + equivalent_timezones = {13: -11, 14: -10} + + def __init__(self, eipconfig, tz_offset=None): + ''' + Constructor for VPNGatewaySelector. + + :param eipconfig: a valid EIP Configuration. + :type eipconfig: EIPConfig + :param tz_offset: use this offset as a local distance to GMT. + :type tz_offset: int + ''' + leap_assert_type(eipconfig, EIPConfig) + + self._local_offset = tz_offset + if tz_offset is None: + tz_offset = self._get_local_offset() + + if tz_offset in self.equivalent_timezones: + tz_offset = self.equivalent_timezones[tz_offset] + + self._local_offset = tz_offset + + self._eipconfig = eipconfig + + def get_gateways(self): + """ + Returns the 4 best gateways, sorted by timezone proximity. + + :rtype: list of IPv4Address or IPv6Address object. + """ + gateways_timezones = [] + locations = self._eipconfig.get_locations() + gateways = self._eipconfig.get_gateways() + + for idx, gateway in enumerate(gateways): + gateway_location = gateway.get('location') + gateway_distance = 99 # if hasn't location -> should go last + + if gateway_location is not None: + gw_offset = int(locations[gateway['location']]['timezone']) + if gw_offset in self.equivalent_timezones: + gw_offset = self.equivalent_timezones[gw_offset] + + gateway_distance = self._get_timezone_distance(gw_offset) + + ip = self._eipconfig.get_gateway_ip(idx) + gateways_timezones.append((ip, gateway_distance)) + + gateways_timezones = sorted(gateways_timezones, + key=lambda gw: gw[1])[:4] + + gateways = [ip for ip, dist in gateways_timezones] + return gateways + + def _get_timezone_distance(self, offset): + ''' + Returns the distance between the local timezone and + the one with offset 'offset'. + + :param offset: the distance of a timezone to GMT. + :type offset: int + :returns: distance between local offset and param offset. + :rtype: int + ''' + timezones = range(-11, 13) + tz1 = offset + tz2 = self._local_offset + distance = abs(timezones.index(tz1) - timezones.index(tz2)) + if distance > 12: + if tz1 < 0: + distance = timezones.index(tz1) + timezones[::-1].index(tz2) + else: + distance = timezones[::-1].index(tz1) + timezones.index(tz2) + + return distance + + def _get_local_offset(self): + ''' + Returns the distance between GMT and the local timezone. + + :rtype: int + ''' + local_offset = time.timezone + if time.daylight: + local_offset = time.altzone + + return local_offset / 3600 + + +class EIPConfig(BaseConfig): + """ + Provider configuration abstraction class + """ + OPENVPN_ALLOWED_KEYS = ("auth", "cipher", "tls-cipher") + OPENVPN_CIPHERS_REGEX = re.compile("[A-Z0-9\-]+") + + def __init__(self): + BaseConfig.__init__(self) + + def _get_spec(self): + """ + Returns the spec object for the specific configuration + """ + return eipservice_config_spec + + def get_clusters(self): + # TODO: create an abstraction for clusters + return self._safe_get_value("clusters") + + def get_gateways(self): + # TODO: create an abstraction for gateways + return self._safe_get_value("gateways") + + def get_locations(self): + ''' + Returns a list of locations + + :rtype: dict + ''' + return self._safe_get_value("locations") + + def get_openvpn_configuration(self): + """ + Returns a dictionary containing the openvpn configuration + parameters. + + These are sanitized with alphanumeric whitelist. + + :returns: openvpn configuration dict + :rtype: C{dict} + """ + ovpncfg = self._safe_get_value("openvpn_configuration") + config = {} + for key, value in ovpncfg.items(): + if key in self.OPENVPN_ALLOWED_KEYS and value is not None: + sanitized_val = self.OPENVPN_CIPHERS_REGEX.findall(value) + if len(sanitized_val) != 0: + _val = sanitized_val[0] + config[str(key)] = str(_val) + return config + + def get_serial(self): + return self._safe_get_value("serial") + + def get_version(self): + return self._safe_get_value("version") + + def get_gateway_ip(self, index=0): + """ + Returns the ip of the gateway. + + :rtype: An IPv4Address or IPv6Address object. + """ + gateways = self.get_gateways() + leap_assert(len(gateways) > 0, "We don't have any gateway!") + if index > len(gateways): + index = 0 + logger.warning("Provided an unknown gateway index %s, " + + "defaulting to 0") + ip_addr_str = gateways[index]["ip_address"] + + try: + ipaddr.IPAddress(ip_addr_str) + return ip_addr_str + except ValueError: + logger.error("Invalid ip address in config: %s" % (ip_addr_str,)) + return None + + def get_client_cert_path(self, + providerconfig=None, + about_to_download=False): + """ + Returns the path to the certificate used by openvpn + """ + + leap_assert(providerconfig, "We need a provider") + leap_assert_type(providerconfig, ProviderConfig) + + cert_path = os.path.join(self.get_path_prefix(), + "leap", + "providers", + providerconfig.get_domain(), + "keys", + "client", + "openvpn.pem") + + if not about_to_download: + leap_assert(os.path.exists(cert_path), + "You need to download the certificate first") + logger.debug("Using OpenVPN cert %s" % (cert_path,)) + + return cert_path + + +if __name__ == "__main__": + logger = logging.getLogger(name='leap') + logger.setLevel(logging.DEBUG) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '%(asctime)s ' + '- %(name)s - %(levelname)s - %(message)s') + console.setFormatter(formatter) + logger.addHandler(console) + + eipconfig = EIPConfig() + + try: + eipconfig.get_clusters() + except Exception as e: + assert isinstance(e, AssertionError), "Expected an assert" + print "Safe value getting is working" + + if eipconfig.load("leap/providers/bitmask.net/eip-service.json"): + print eipconfig.get_clusters() + print eipconfig.get_gateways() + print eipconfig.get_locations() + print eipconfig.get_openvpn_configuration() + print eipconfig.get_serial() + print eipconfig.get_version() diff --git a/src/leap/services/eip/eipspec.py b/src/leap/services/eip/eipspec.py new file mode 100644 index 00000000..94ba674f --- /dev/null +++ b/src/leap/services/eip/eipspec.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# eipspec.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +eipservice_config_spec = { + 'description': 'sample eip service config', + 'type': 'object', + 'properties': { + 'serial': { + 'type': int, + 'default': 1, + 'required': ["True"] + }, + 'version': { + 'type': int, + 'default': 1, + 'required': ["True"] + }, + 'clusters': { + 'type': list, + 'default': [ + {"label": { + "en": "Location Unknown"}, + "name": "location_unknown"}] + }, + 'gateways': { + 'type': list, + 'default': [ + {"capabilities": { + "adblock": True, + "filter_dns": True, + "ports": ["80", "53", "443", "1194"], + "protocols": ["udp", "tcp"], + "transport": ["openvpn"], + "user_ips": False}, + "cluster": "location_unknown", + "host": "location.example.org", + "ip_address": "127.0.0.1"}] + }, + 'locations': { + 'type': dict, + 'default': {} + }, + 'openvpn_configuration': { + 'type': dict, + 'default': { + "auth": None, + "cipher": None, + "tls-cipher": None} + } + } +} diff --git a/src/leap/services/eip/providerbootstrapper.py b/src/leap/services/eip/providerbootstrapper.py new file mode 100644 index 00000000..754d0643 --- /dev/null +++ b/src/leap/services/eip/providerbootstrapper.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +# providerbootstrapper.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +""" +Provider bootstrapping +""" +import logging +import socket +import os + +import requests + +from PySide import QtCore + +from leap.common.certs import get_digest +from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p +from leap.common.check import leap_assert, leap_assert_type +from leap.config.providerconfig import ProviderConfig +from leap.util.request_helpers import get_content +from leap.services.abstractbootstrapper import AbstractBootstrapper +from leap.provider.supportedapis import SupportedAPIs + + +logger = logging.getLogger(__name__) + + +class UnsupportedProviderAPI(Exception): + """ + Raised when attempting to use a provider with an incompatible API. + """ + pass + + +class ProviderBootstrapper(AbstractBootstrapper): + """ + Given a provider URL performs a series of checks and emits signals + after they are passed. + If a check fails, the subsequent checks are not executed + """ + + # All dicts returned are of the form + # {"passed": bool, "error": str} + name_resolution = QtCore.Signal(dict) + https_connection = QtCore.Signal(dict) + download_provider_info = QtCore.Signal(dict) + + download_ca_cert = QtCore.Signal(dict) + check_ca_fingerprint = QtCore.Signal(dict) + check_api_certificate = QtCore.Signal(dict) + + def __init__(self, bypass_checks=False): + """ + Constructor for provider bootstrapper object + + :param bypass_checks: Set to true if the app should bypass + first round of checks for CA certificates at bootstrap + :type bypass_checks: bool + """ + AbstractBootstrapper.__init__(self, bypass_checks) + + self._domain = None + self._provider_config = None + self._download_if_needed = False + + def _check_name_resolution(self): + """ + Checks that the name resolution for the provider name works + """ + leap_assert(self._domain, "Cannot check DNS without a domain") + + logger.debug("Checking name resolution for %s" % (self._domain)) + + # We don't skip this check, since it's basic for the whole + # system to work + socket.gethostbyname(self._domain) + + def _check_https(self, *args): + """ + Checks that https is working and that the provided certificate + checks out + """ + + leap_assert(self._domain, "Cannot check HTTPS without a domain") + + logger.debug("Checking https for %s" % (self._domain)) + + # We don't skip this check, since it's basic for the whole + # system to work + + try: + res = self._session.get("https://%s" % (self._domain,), + verify=not self._bypass_checks) + res.raise_for_status() + except requests.exceptions.SSLError: + self._err_msg = self.tr("Provider certificate could " + "not be verified") + raise + except Exception: + self._err_msg = self.tr("Provider does not support HTTPS") + raise + + def _download_provider_info(self, *args): + """ + Downloads the provider.json defition + """ + leap_assert(self._domain, + "Cannot download provider info without a domain") + + logger.debug("Downloading provider info for %s" % (self._domain)) + + headers = {} + mtime = get_mtime(os.path.join(ProviderConfig() + .get_path_prefix(), + "leap", + "providers", + self._domain, + "provider.json")) + if self._download_if_needed and mtime: + headers['if-modified-since'] = mtime + + res = self._session.get("https://%s/%s" % (self._domain, + "provider.json"), + headers=headers, + verify=not self._bypass_checks) + res.raise_for_status() + + # Not modified + if res.status_code == 304: + logger.debug("Provider definition has not been modified") + else: + provider_definition, mtime = get_content(res) + + provider_config = ProviderConfig() + provider_config.load(data=provider_definition, mtime=mtime) + provider_config.save(["leap", + "providers", + self._domain, + "provider.json"]) + + api_version = provider_config.get_api_version() + if SupportedAPIs.supports(api_version): + logger.debug("Provider definition has been modified") + else: + api_supported = ', '.join(SupportedAPIs.SUPPORTED_APIS) + error = ('Unsupported provider API version. ' + 'Supported versions are: {}. ' + 'Found: {}.').format(api_supported, api_version) + + logger.error(error) + raise UnsupportedProviderAPI(error) + + def run_provider_select_checks(self, domain, download_if_needed=False): + """ + Populates the check queue. + + :param domain: domain to check + :type domain: str + + :param download_if_needed: if True, makes the checks do not + overwrite already downloaded data + :type download_if_needed: bool + """ + leap_assert(domain and len(domain) > 0, "We need a domain!") + + self._domain = domain + self._download_if_needed = download_if_needed + + cb_chain = [ + (self._check_name_resolution, self.name_resolution), + (self._check_https, self.https_connection), + (self._download_provider_info, self.download_provider_info) + ] + + return self.addCallbackChain(cb_chain) + + def _should_proceed_cert(self): + """ + Returns False if the certificate already exists for the given + provider. True otherwise + + :rtype: bool + """ + leap_assert(self._provider_config, "We need a provider config!") + + if not self._download_if_needed: + return True + + return not os.path.exists(self._provider_config + .get_ca_cert_path(about_to_download=True)) + + def _download_ca_cert(self, *args): + """ + Downloads the CA cert that is going to be used for the api URL + """ + + leap_assert(self._provider_config, "Cannot download the ca cert " + "without a provider config!") + + logger.debug("Downloading ca cert for %s at %s" % + (self._domain, self._provider_config.get_ca_cert_uri())) + + if not self._should_proceed_cert(): + check_and_fix_urw_only( + self._provider_config + .get_ca_cert_path(about_to_download=True)) + return + + res = self._session.get(self._provider_config.get_ca_cert_uri(), + verify=not self._bypass_checks) + res.raise_for_status() + + cert_path = self._provider_config.get_ca_cert_path( + about_to_download=True) + cert_dir = os.path.dirname(cert_path) + mkdir_p(cert_dir) + with open(cert_path, "w") as f: + f.write(res.content) + + check_and_fix_urw_only(cert_path) + + def _check_ca_fingerprint(self, *args): + """ + Checks the CA cert fingerprint against the one provided in the + json definition + """ + leap_assert(self._provider_config, "Cannot check the ca cert " + "without a provider config!") + + logger.debug("Checking ca fingerprint for %s and cert %s" % + (self._domain, + self._provider_config.get_ca_cert_path())) + + if not self._should_proceed_cert(): + return + + parts = self._provider_config.get_ca_cert_fingerprint().split(":") + leap_assert(len(parts) == 2, "Wrong fingerprint format") + + method = parts[0].strip() + fingerprint = parts[1].strip() + cert_data = None + with open(self._provider_config.get_ca_cert_path()) as f: + cert_data = f.read() + + leap_assert(len(cert_data) > 0, "Could not read certificate data") + digest = get_digest(cert_data, method) + leap_assert(digest == fingerprint, + "Downloaded certificate has a different fingerprint!") + + def _check_api_certificate(self, *args): + """ + Tries to make an API call with the downloaded cert and checks + if it validates against it + """ + leap_assert(self._provider_config, "Cannot check the ca cert " + "without a provider config!") + + logger.debug("Checking api certificate for %s and cert %s" % + (self._provider_config.get_api_uri(), + self._provider_config.get_ca_cert_path())) + + if not self._should_proceed_cert(): + return + + test_uri = "%s/%s/cert" % (self._provider_config.get_api_uri(), + self._provider_config.get_api_version()) + res = self._session.get(test_uri, + verify=self._provider_config + .get_ca_cert_path()) + res.raise_for_status() + + def run_provider_setup_checks(self, + provider_config, + download_if_needed=False): + """ + Starts the checks needed for a new provider setup. + + :param provider_config: Provider configuration + :type provider_config: ProviderConfig + + :param download_if_needed: if True, makes the checks do not + overwrite already downloaded data. + :type download_if_needed: bool + """ + leap_assert(provider_config, "We need a provider config!") + leap_assert_type(provider_config, ProviderConfig) + + self._provider_config = provider_config + self._download_if_needed = download_if_needed + + cb_chain = [ + (self._download_ca_cert, self.download_ca_cert), + (self._check_ca_fingerprint, self.check_ca_fingerprint), + (self._check_api_certificate, self.check_api_certificate) + ] + + return self.addCallbackChain(cb_chain) diff --git a/src/leap/services/eip/tests/__init__.py b/src/leap/services/eip/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/services/eip/tests/__init__.py diff --git a/src/leap/services/eip/tests/test_eipbootstrapper.py b/src/leap/services/eip/tests/test_eipbootstrapper.py new file mode 100644 index 00000000..f2331eca --- /dev/null +++ b/src/leap/services/eip/tests/test_eipbootstrapper.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- +# test_eipbootstrapper.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +Tests for the EIP Boostrapper checks + +These will be whitebox tests since we want to make sure the private +implementation is checking what we expect. +""" + +import os +import mock +import tempfile +import time +try: + import unittest2 as unittest +except ImportError: + import unittest + +from nose.twistedtools import deferred, reactor +from twisted.internet import threads +from requests.models import Response + +from leap.common.testing.basetest import BaseLeapTest +from leap.services.eip.eipbootstrapper import EIPBootstrapper +from leap.services.eip.eipconfig import EIPConfig +from leap.config.providerconfig import ProviderConfig +from leap.crypto.tests import fake_provider +from leap.common.files import mkdir_p +from leap.crypto.srpauth import SRPAuth + + +class EIPBootstrapperActiveTest(BaseLeapTest): + @classmethod + def setUpClass(cls): + BaseLeapTest.setUpClass() + factory = fake_provider.get_provider_factory() + http = reactor.listenTCP(0, factory) + https = reactor.listenSSL( + 0, factory, + fake_provider.OpenSSLServerContextFactory()) + get_port = lambda p: p.getHost().port + cls.http_port = get_port(http) + cls.https_port = get_port(https) + + def setUp(self): + self.eb = EIPBootstrapper() + self.old_pp = EIPConfig.get_path_prefix + self.old_save = EIPConfig.save + self.old_load = EIPConfig.load + self.old_si = SRPAuth.get_session_id + + def tearDown(self): + EIPConfig.get_path_prefix = self.old_pp + EIPConfig.save = self.old_save + EIPConfig.load = self.old_load + SRPAuth.get_session_id = self.old_si + + def _download_config_test_template(self, ifneeded, new): + """ + All download config tests have the same structure, so this is + a parametrized test for that. + + :param ifneeded: sets _download_if_needed + :type ifneeded: bool + :param new: if True uses time.time() as mtime for the mocked + eip-service file, otherwise it uses 100 (a really + old mtime) + :type new: float or int (will be coersed) + """ + pc = ProviderConfig() + pc.get_domain = mock.MagicMock( + return_value="localhost:%s" % (self.https_port)) + self.eb._provider_config = pc + + pc.get_api_uri = mock.MagicMock( + return_value="https://%s" % (pc.get_domain())) + pc.get_api_version = mock.MagicMock(return_value="1") + + # This is to ignore https checking, since it's not the point + # of this test + pc.get_ca_cert_path = mock.MagicMock(return_value=False) + + path_prefix = tempfile.mkdtemp() + EIPConfig.get_path_prefix = mock.MagicMock(return_value=path_prefix) + EIPConfig.save = mock.MagicMock() + EIPConfig.load = mock.MagicMock() + + self.eb._download_if_needed = ifneeded + + provider_dir = os.path.join(EIPConfig.get_path_prefix(), + "leap", + "providers", + pc.get_domain()) + mkdir_p(provider_dir) + eip_config_path = os.path.join(provider_dir, + "eip-service.json") + + with open(eip_config_path, "w") as ec: + ec.write("A") + + # set mtime to something really new + if new: + os.utime(eip_config_path, (-1, time.time())) + else: + os.utime(eip_config_path, (-1, 100)) + + @deferred() + def test_download_config_not_modified(self): + self._download_config_test_template(True, True) + + d = threads.deferToThread(self.eb._download_config) + + def check(*args): + self.assertFalse(self.eb._eip_config.save.called) + d.addCallback(check) + return d + + @deferred() + def test_download_config_modified(self): + self._download_config_test_template(True, False) + + d = threads.deferToThread(self.eb._download_config) + + def check(*args): + self.assertTrue(self.eb._eip_config.save.called) + d.addCallback(check) + return d + + @deferred() + def test_download_config_ignores_mtime(self): + self._download_config_test_template(False, True) + + d = threads.deferToThread(self.eb._download_config) + + def check(*args): + self.eb._eip_config.save.assert_called_once_with( + ["leap", + "providers", + self.eb._provider_config.get_domain(), + "eip-service.json"]) + d.addCallback(check) + return d + + def _download_certificate_test_template(self, ifneeded, createcert): + """ + All download client certificate tests have the same structure, + so this is a parametrized test for that. + + :param ifneeded: sets _download_if_needed + :type ifneeded: bool + :param createcert: if True it creates a dummy file to play the + part of a downloaded certificate + :type createcert: bool + + :returns: the temp eip cert path and the dummy cert contents + :rtype: tuple of str, str + """ + pc = ProviderConfig() + ec = EIPConfig() + self.eb._provider_config = pc + self.eb._eip_config = ec + + pc.get_domain = mock.MagicMock( + return_value="localhost:%s" % (self.https_port)) + pc.get_api_uri = mock.MagicMock( + return_value="https://%s" % (pc.get_domain())) + pc.get_api_version = mock.MagicMock(return_value="1") + pc.get_ca_cert_path = mock.MagicMock(return_value=False) + + path_prefix = tempfile.mkdtemp() + EIPConfig.get_path_prefix = mock.MagicMock(return_value=path_prefix) + EIPConfig.save = mock.MagicMock() + EIPConfig.load = mock.MagicMock() + + self.eb._download_if_needed = ifneeded + + provider_dir = os.path.join(EIPConfig.get_path_prefix(), + "leap", + "providers", + "somedomain") + mkdir_p(provider_dir) + eip_cert_path = os.path.join(provider_dir, + "cert") + + ec.get_client_cert_path = mock.MagicMock( + return_value=eip_cert_path) + + cert_content = "A" + if createcert: + with open(eip_cert_path, "w") as ec: + ec.write(cert_content) + + return eip_cert_path, cert_content + + def test_download_client_certificate_not_modified(self): + cert_path, old_cert_content = self._download_certificate_test_template( + True, True) + + with mock.patch('leap.common.certs.should_redownload', + new_callable=mock.MagicMock, + return_value=False): + self.eb._download_client_certificates() + with open(cert_path, "r") as c: + self.assertEqual(c.read(), old_cert_content) + + @deferred() + def test_download_client_certificate_old_cert(self): + cert_path, old_cert_content = self._download_certificate_test_template( + True, True) + + def wrapper(*args): + with mock.patch('leap.common.certs.should_redownload', + new_callable=mock.MagicMock, + return_value=True): + with mock.patch('leap.common.certs.is_valid_pemfile', + new_callable=mock.MagicMock, + return_value=True): + self.eb._download_client_certificates() + + def check(*args): + with open(cert_path, "r") as c: + self.assertNotEqual(c.read(), old_cert_content) + d = threads.deferToThread(wrapper) + d.addCallback(check) + + return d + + @deferred() + def test_download_client_certificate_no_cert(self): + cert_path, _ = self._download_certificate_test_template( + True, False) + + def wrapper(*args): + with mock.patch('leap.common.certs.should_redownload', + new_callable=mock.MagicMock, + return_value=False): + with mock.patch('leap.common.certs.is_valid_pemfile', + new_callable=mock.MagicMock, + return_value=True): + self.eb._download_client_certificates() + + def check(*args): + self.assertTrue(os.path.exists(cert_path)) + d = threads.deferToThread(wrapper) + d.addCallback(check) + + return d + + @deferred() + def test_download_client_certificate_force_not_valid(self): + cert_path, old_cert_content = self._download_certificate_test_template( + True, True) + + def wrapper(*args): + with mock.patch('leap.common.certs.should_redownload', + new_callable=mock.MagicMock, + return_value=True): + with mock.patch('leap.common.certs.is_valid_pemfile', + new_callable=mock.MagicMock, + return_value=True): + self.eb._download_client_certificates() + + def check(*args): + with open(cert_path, "r") as c: + self.assertNotEqual(c.read(), old_cert_content) + d = threads.deferToThread(wrapper) + d.addCallback(check) + + return d + + @deferred() + def test_download_client_certificate_invalid_download(self): + cert_path, _ = self._download_certificate_test_template( + False, False) + + def wrapper(*args): + with mock.patch('leap.common.certs.should_redownload', + new_callable=mock.MagicMock, + return_value=True): + with mock.patch('leap.common.certs.is_valid_pemfile', + new_callable=mock.MagicMock, + return_value=False): + with self.assertRaises(Exception): + self.eb._download_client_certificates() + d = threads.deferToThread(wrapper) + + return d + + @deferred() + def test_download_client_certificate_uses_session_id(self): + _, _ = self._download_certificate_test_template( + False, False) + + SRPAuth.get_session_id = mock.MagicMock(return_value="1") + + def check_cookie(*args, **kwargs): + cookies = kwargs.get("cookies", None) + self.assertEqual(cookies, {'_session_id': '1'}) + return Response() + + def wrapper(*args): + with mock.patch('leap.common.certs.should_redownload', + new_callable=mock.MagicMock, + return_value=False): + with mock.patch('leap.common.certs.is_valid_pemfile', + new_callable=mock.MagicMock, + return_value=True): + with mock.patch('requests.sessions.Session.get', + new_callable=mock.MagicMock, + side_effect=check_cookie): + with mock.patch('requests.models.Response.content', + new_callable=mock.PropertyMock, + return_value="A"): + self.eb._download_client_certificates() + + d = threads.deferToThread(wrapper) + + return d + + @deferred() + def test_run_eip_setup_checks(self): + self.eb._download_config = mock.MagicMock() + self.eb._download_client_certificates = mock.MagicMock() + + d = self.eb.run_eip_setup_checks(ProviderConfig()) + + def check(*args): + self.eb._download_config.assert_called_once_with() + self.eb._download_client_certificates.assert_called_once_with(None) + d.addCallback(check) + return d diff --git a/src/leap/services/eip/tests/test_eipconfig.py b/src/leap/services/eip/tests/test_eipconfig.py new file mode 100644 index 00000000..8b746b78 --- /dev/null +++ b/src/leap/services/eip/tests/test_eipconfig.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- +# test_eipconfig.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for eipconfig +""" +import copy +import json +import os +import unittest + +from leap.common.testing.basetest import BaseLeapTest +from leap.services.eip.eipconfig import EIPConfig +from leap.config.providerconfig import ProviderConfig + +from mock import Mock + + +sample_config = { + "gateways": [ + { + "capabilities": { + "adblock": False, + "filter_dns": True, + "limited": True, + "ports": [ + "1194", + "443", + "53", + "80"], + "protocols": [ + "tcp", + "udp"], + "transport": ["openvpn"], + "user_ips": False}, + "host": "host.dev.example.org", + "ip_address": "11.22.33.44", + "location": "cyberspace" + }, { + "capabilities": { + "adblock": False, + "filter_dns": True, + "limited": True, + "ports": [ + "1194", + "443", + "53", + "80"], + "protocols": [ + "tcp", + "udp"], + "transport": ["openvpn"], + "user_ips": False}, + "host": "host2.dev.example.org", + "ip_address": "22.33.44.55", + "location": "cyberspace" + } + ], + "locations": { + "ankara": { + "country_code": "XX", + "hemisphere": "S", + "name": "Antarctica", + "timezone": "+2" + }, + "cyberspace": { + "country_code": "XX", + "hemisphere": "X", + "name": "outer space", + "timezone": "" + } + }, + "openvpn_configuration": { + "auth": "SHA1", + "cipher": "AES-128-CBC", + "tls-cipher": "DHE-RSA-AES128-SHA" + }, + "serial": 1, + "version": 1 +} + + +class EIPConfigTest(BaseLeapTest): + + __name__ = "eip_config_tests" + + maxDiff = None + + def setUp(self): + self._old_ospath_exists = os.path.exists + + def tearDown(self): + os.path.exists = self._old_ospath_exists + + def _write_config(self, data): + """ + Helper to write some data to a temp config file. + + :param data: data to be used to save in the config file. + :data type: dict (valid json) + """ + self.configfile = os.path.join(self.tempdir, "eipconfig.json") + conf = open(self.configfile, "w") + conf.write(json.dumps(data)) + conf.close() + + def _get_eipconfig(self, fromfile=True, data=sample_config): + """ + Helper that returns an EIPConfig object using the data parameter + or a sample data. + + :param fromfile: sets if we should use a file or a string + :fromfile type: bool + :param data: sets the data to be used to load in the EIPConfig object + :data type: dict (valid json) + :rtype: EIPConfig + """ + config = EIPConfig() + + loaded = False + if fromfile: + self._write_config(data) + loaded = config.load(self.configfile, relative=False) + else: + json_string = json.dumps(data) + loaded = config.load(data=json_string) + + if not loaded: + return None + + return config + + def test_loads_from_file(self): + config = self._get_eipconfig() + self.assertIsNotNone(config) + + def test_loads_from_data(self): + config = self._get_eipconfig(fromfile=False) + self.assertIsNotNone(config) + + def test_load_valid_config_from_file(self): + config = self._get_eipconfig() + self.assertIsNotNone(config) + + self.assertEqual( + config.get_openvpn_configuration(), + sample_config["openvpn_configuration"]) + + sample_ip = sample_config["gateways"][0]["ip_address"] + self.assertEqual( + config.get_gateway_ip(), + sample_ip) + self.assertEqual(config.get_version(), sample_config["version"]) + self.assertEqual(config.get_serial(), sample_config["serial"]) + self.assertEqual(config.get_gateways(), sample_config["gateways"]) + self.assertEqual(config.get_locations(), sample_config["locations"]) + self.assertEqual(config.get_clusters(), None) + + def test_load_valid_config_from_data(self): + config = self._get_eipconfig(fromfile=False) + self.assertIsNotNone(config) + + self.assertEqual( + config.get_openvpn_configuration(), + sample_config["openvpn_configuration"]) + + sample_ip = sample_config["gateways"][0]["ip_address"] + self.assertEqual( + config.get_gateway_ip(), + sample_ip) + + self.assertEqual(config.get_version(), sample_config["version"]) + self.assertEqual(config.get_serial(), sample_config["serial"]) + self.assertEqual(config.get_gateways(), sample_config["gateways"]) + self.assertEqual(config.get_locations(), sample_config["locations"]) + self.assertEqual(config.get_clusters(), None) + + def test_sanitize_extra_parameters(self): + data = copy.deepcopy(sample_config) + data['openvpn_configuration']["extra_param"] = "FOO" + config = self._get_eipconfig(data=data) + + self.assertEqual( + config.get_openvpn_configuration(), + sample_config["openvpn_configuration"]) + + def test_sanitize_non_allowed_chars(self): + data = copy.deepcopy(sample_config) + data['openvpn_configuration']["auth"] = "SHA1;" + config = self._get_eipconfig(data=data) + + self.assertEqual( + config.get_openvpn_configuration(), + sample_config["openvpn_configuration"]) + + data = copy.deepcopy(sample_config) + data['openvpn_configuration']["auth"] = "SHA1>`&|" + config = self._get_eipconfig(data=data) + + self.assertEqual( + config.get_openvpn_configuration(), + sample_config["openvpn_configuration"]) + + def test_sanitize_lowercase(self): + data = copy.deepcopy(sample_config) + data['openvpn_configuration']["auth"] = "shaSHA1" + config = self._get_eipconfig(data=data) + + self.assertEqual( + config.get_openvpn_configuration(), + sample_config["openvpn_configuration"]) + + def test_all_characters_invalid(self): + data = copy.deepcopy(sample_config) + data['openvpn_configuration']["auth"] = "sha&*!@#;" + config = self._get_eipconfig(data=data) + + self.assertEqual( + config.get_openvpn_configuration(), + {'cipher': 'AES-128-CBC', + 'tls-cipher': 'DHE-RSA-AES128-SHA'}) + + def test_sanitize_bad_ip(self): + data = copy.deepcopy(sample_config) + data['gateways'][0]["ip_address"] = "11.22.33.44;" + config = self._get_eipconfig(data=data) + + self.assertEqual(config.get_gateway_ip(), None) + + data = copy.deepcopy(sample_config) + data['gateways'][0]["ip_address"] = "11.22.33.44`" + config = self._get_eipconfig(data=data) + + self.assertEqual(config.get_gateway_ip(), None) + + def test_default_gateway_on_unknown_index(self): + config = self._get_eipconfig() + sample_ip = sample_config["gateways"][0]["ip_address"] + self.assertEqual(config.get_gateway_ip(999), sample_ip) + + def test_get_gateway_by_index(self): + config = self._get_eipconfig() + sample_ip_0 = sample_config["gateways"][0]["ip_address"] + sample_ip_1 = sample_config["gateways"][1]["ip_address"] + self.assertEqual(config.get_gateway_ip(0), sample_ip_0) + self.assertEqual(config.get_gateway_ip(1), sample_ip_1) + + def test_get_client_cert_path_as_expected(self): + config = self._get_eipconfig() + config.get_path_prefix = Mock(return_value='test') + + provider_config = ProviderConfig() + + # mock 'get_domain' so we don't need to load a config + provider_domain = 'test.provider.com' + provider_config.get_domain = Mock(return_value=provider_domain) + + expected_path = os.path.join('test', 'leap', 'providers', + provider_domain, 'keys', 'client', + 'openvpn.pem') + + # mock 'os.path.exists' so we don't get an error for unexisting file + os.path.exists = Mock(return_value=True) + cert_path = config.get_client_cert_path(provider_config) + + self.assertEqual(cert_path, expected_path) + + def test_get_client_cert_path_about_to_download(self): + config = self._get_eipconfig() + config.get_path_prefix = Mock(return_value='test') + + provider_config = ProviderConfig() + + # mock 'get_domain' so we don't need to load a config + provider_domain = 'test.provider.com' + provider_config.get_domain = Mock(return_value=provider_domain) + + expected_path = os.path.join('test', 'leap', 'providers', + provider_domain, 'keys', 'client', + 'openvpn.pem') + + cert_path = config.get_client_cert_path( + provider_config, about_to_download=True) + + self.assertEqual(cert_path, expected_path) + + def test_get_client_cert_path_fails(self): + config = self._get_eipconfig() + provider_config = ProviderConfig() + + # mock 'get_domain' so we don't need to load a config + provider_domain = 'test.provider.com' + provider_config.get_domain = Mock(return_value=provider_domain) + + with self.assertRaises(AssertionError): + config.get_client_cert_path(provider_config) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/services/eip/tests/test_providerbootstrapper.py b/src/leap/services/eip/tests/test_providerbootstrapper.py new file mode 100644 index 00000000..cd740793 --- /dev/null +++ b/src/leap/services/eip/tests/test_providerbootstrapper.py @@ -0,0 +1,504 @@ +# -*- coding: utf-8 -*- +# test_providerbootstrapper.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +Tests for the Provider Boostrapper checks + +These will be whitebox tests since we want to make sure the private +implementation is checking what we expect. +""" + +import os +import mock +import socket +import stat +import tempfile +import time +import requests +try: + import unittest2 as unittest +except ImportError: + import unittest + +from nose.twistedtools import deferred, reactor +from twisted.internet import threads +from requests.models import Response + +from leap.common.testing.https_server import where +from leap.common.testing.basetest import BaseLeapTest +from leap.services.eip.providerbootstrapper import ProviderBootstrapper +from leap.services.eip.providerbootstrapper import UnsupportedProviderAPI +from leap.provider.supportedapis import SupportedAPIs +from leap.config.providerconfig import ProviderConfig +from leap.crypto.tests import fake_provider +from leap.common.files import mkdir_p + + +class ProviderBootstrapperTest(BaseLeapTest): + def setUp(self): + self.pb = ProviderBootstrapper() + + def tearDown(self): + pass + + def test_name_resolution_check(self): + # Something highly likely to success + self.pb._domain = "google.com" + self.pb._check_name_resolution() + # Something highly likely to fail + self.pb._domain = "uquhqweuihowquie.abc.def" + with self.assertRaises(socket.gaierror): + self.pb._check_name_resolution() + + @deferred() + def test_run_provider_select_checks(self): + self.pb._check_name_resolution = mock.MagicMock() + self.pb._check_https = mock.MagicMock() + self.pb._download_provider_info = mock.MagicMock() + + d = self.pb.run_provider_select_checks("somedomain") + + def check(*args): + self.pb._check_name_resolution.assert_called_once_with() + self.pb._check_https.assert_called_once_with(None) + self.pb._download_provider_info.assert_called_once_with(None) + d.addCallback(check) + return d + + @deferred() + def test_run_provider_setup_checks(self): + self.pb._download_ca_cert = mock.MagicMock() + self.pb._check_ca_fingerprint = mock.MagicMock() + self.pb._check_api_certificate = mock.MagicMock() + + d = self.pb.run_provider_setup_checks(ProviderConfig()) + + def check(*args): + self.pb._download_ca_cert.assert_called_once_with() + self.pb._check_ca_fingerprint.assert_called_once_with(None) + self.pb._check_api_certificate.assert_called_once_with(None) + d.addCallback(check) + return d + + def test_should_proceed_cert(self): + self.pb._provider_config = mock.Mock() + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value=where("cacert.pem")) + + self.pb._download_if_needed = False + self.assertTrue(self.pb._should_proceed_cert()) + + self.pb._download_if_needed = True + self.assertFalse(self.pb._should_proceed_cert()) + + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value=where("somefilethatdoesntexist.pem")) + self.assertTrue(self.pb._should_proceed_cert()) + + def _check_download_ca_cert(self, should_proceed): + """ + Helper to check different paths easily for the download ca + cert check + + :param should_proceed: sets the _should_proceed_cert in the + provider bootstrapper being tested + :type should_proceed: bool + + :returns: The contents of the certificate, the expected + content depending on should_proceed, and the mode of + the file to be checked by the caller + :rtype: tuple of str, str, int + """ + old_content = "NOT THE NEW CERT" + new_content = "NEW CERT" + new_cert_path = os.path.join(tempfile.mkdtemp(), + "mynewcert.pem") + + with open(new_cert_path, "w") as c: + c.write(old_content) + + self.pb._provider_config = mock.Mock() + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value=new_cert_path) + self.pb._domain = "somedomain" + + self.pb._should_proceed_cert = mock.MagicMock( + return_value=should_proceed) + + read = None + content_to_check = None + mode = None + + with mock.patch('requests.models.Response.content', + new_callable=mock.PropertyMock) as \ + content: + content.return_value = new_content + response_obj = Response() + response_obj.raise_for_status = mock.MagicMock() + + self.pb._session.get = mock.MagicMock(return_value=response_obj) + self.pb._download_ca_cert() + with open(new_cert_path, "r") as nc: + read = nc.read() + if should_proceed: + content_to_check = new_content + else: + content_to_check = old_content + mode = stat.S_IMODE(os.stat(new_cert_path).st_mode) + + os.unlink(new_cert_path) + return read, content_to_check, mode + + def test_download_ca_cert_no_saving(self): + read, expected_read, mode = self._check_download_ca_cert(False) + self.assertEqual(read, expected_read) + self.assertEqual(mode, int("600", 8)) + + def test_download_ca_cert_saving(self): + read, expected_read, mode = self._check_download_ca_cert(True) + self.assertEqual(read, expected_read) + self.assertEqual(mode, int("600", 8)) + + def test_check_ca_fingerprint_skips(self): + self.pb._provider_config = mock.Mock() + self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock( + return_value="") + self.pb._domain = "somedomain" + + self.pb._should_proceed_cert = mock.MagicMock(return_value=False) + + self.pb._check_ca_fingerprint() + self.assertFalse(self.pb._provider_config. + get_ca_cert_fingerprint.called) + + def test_check_ca_cert_fingerprint_raises_bad_format(self): + self.pb._provider_config = mock.Mock() + self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock( + return_value="wrongfprformat!!") + self.pb._domain = "somedomain" + + self.pb._should_proceed_cert = mock.MagicMock(return_value=True) + + with self.assertRaises(AssertionError): + self.pb._check_ca_fingerprint() + + # This two hashes different in the last byte, but that's good enough + # for the tests + KNOWN_BAD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034efe" \ + "7dd1b910062ca323eb4da5c7f" + KNOWN_GOOD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034ef" \ + "e7dd1b910062ca323eb4da5c7e" + KNOWN_GOOD_CERT = """ +-----BEGIN CERTIFICATE----- +MIIFbzCCA1egAwIBAgIBATANBgkqhkiG9w0BAQ0FADBKMRgwFgYDVQQDDA9CaXRt +YXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNVBAsME2h0dHBzOi8v +Yml0bWFzay5uZXQwHhcNMTIxMTA2MDAwMDAwWhcNMjIxMTA2MDAwMDAwWjBKMRgw +FgYDVQQDDA9CaXRtYXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNV +BAsME2h0dHBzOi8vYml0bWFzay5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAw +ggIKAoICAQC1eV4YvayaU+maJbWrD4OHo3d7S1BtDlcvkIRS1Fw3iYDjsyDkZxai +dHp4EUasfNQ+EVtXUvtk6170EmLco6Elg8SJBQ27trE6nielPRPCfX3fQzETRfvB +7tNvGw4Jn2YKiYoMD79kkjgyZjkJ2r/bEHUSevmR09BRp86syHZerdNGpXYhcQ84 +CA1+V+603GFIHnrP+uQDdssW93rgDNYu+exT+Wj6STfnUkugyjmPRPjL7wh0tzy+ +znCeLl4xiV3g9sjPnc7r2EQKd5uaTe3j71sDPF92KRk0SSUndREz+B1+Dbe/RGk4 +MEqGFuOzrtsgEhPIX0hplhb0Tgz/rtug+yTT7oJjBa3u20AAOQ38/M99EfdeJvc4 +lPFF1XBBLh6X9UKF72an2NuANiX6XPySnJgZ7nZ09RiYZqVwu/qt3DfvLfhboq+0 +bQvLUPXrVDr70onv5UDjpmEA/cLmaIqqrduuTkFZOym65/PfAPvpGnt7crQj/Ibl +DEDYZQmP7AS+6zBjoOzNjUGE5r40zWAR1RSi7zliXTu+yfsjXUIhUAWmYR6J3KxB +lfsiHBQ+8dn9kC3YrUexWoOqBiqJOAJzZh5Y1tqgzfh+2nmHSB2dsQRs7rDRRlyy +YMbkpzL9ZsOUO2eTP1mmar6YjCN+rggYjRrX71K2SpBG6b1zZxOG+wIDAQABo2Aw +XjAdBgNVHQ4EFgQUuYGDLL2sswnYpHHvProt1JU+D48wDgYDVR0PAQH/BAQDAgIE +MAwGA1UdEwQFMAMBAf8wHwYDVR0jBBgwFoAUuYGDLL2sswnYpHHvProt1JU+D48w +DQYJKoZIhvcNAQENBQADggIBADeG67vaFcbITGpi51264kHPYPEWaXUa5XYbtmBl +cXYyB6hY5hv/YNuVGJ1gWsDmdeXEyj0j2icGQjYdHRfwhrbEri+h1EZOm1cSBDuY +k/P5+ctHyOXx8IE79DBsZ6IL61UKIaKhqZBfLGYcWu17DVV6+LT+AKtHhOrv3TSj +RnAcKnCbKqXLhUPXpK0eTjPYS2zQGQGIhIy9sQXVXJJJsGrPgMxna1Xw2JikBOCG +htD/JKwt6xBmNwktH0GI/LVtVgSp82Clbn9C4eZN9E5YbVYjLkIEDhpByeC71QhX +EIQ0ZR56bFuJA/CwValBqV/G9gscTPQqd+iETp8yrFpAVHOW+YzSFbxjTEkBte1J +aF0vmbqdMAWLk+LEFPQRptZh0B88igtx6tV5oVd+p5IVRM49poLhuPNJGPvMj99l +mlZ4+AeRUnbOOeAEuvpLJbel4rhwFzmUiGoeTVoPZyMevWcVFq6BMkS+jRR2w0jK +G6b0v5XDHlcFYPOgUrtsOBFJVwbutLvxdk6q37kIFnWCd8L3kmES5q4wjyFK47Co +Ja8zlx64jmMZPg/t3wWqkZgXZ14qnbyG5/lGsj5CwVtfDljrhN0oCWK1FZaUmW3d +69db12/g4f6phldhxiWuGC/W6fCW5kre7nmhshcltqAJJuU47iX+DarBFiIj816e +yV8e +-----END CERTIFICATE----- +""" + + def _prepare_provider_config_with(self, cert_path, cert_hash): + """ + Mocks the provider config to give the cert_path and cert_hash + specified + + :param cert_path: path for the certificate + :type cert_path: str + :param cert_hash: hash for the certificate as it would appear + in the provider config json + :type cert_hash: str + """ + self.pb._provider_config = mock.Mock() + self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock( + return_value=cert_hash) + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value=cert_path) + self.pb._domain = "somedomain" + + def test_check_ca_fingerprint_checksout(self): + cert_path = os.path.join(tempfile.mkdtemp(), + "mynewcert.pem") + + with open(cert_path, "w") as c: + c.write(self.KNOWN_GOOD_CERT) + + self._prepare_provider_config_with(cert_path, self.KNOWN_GOOD_HASH) + + self.pb._should_proceed_cert = mock.MagicMock(return_value=True) + + self.pb._check_ca_fingerprint() + + os.unlink(cert_path) + + def test_check_ca_fingerprint_fails(self): + cert_path = os.path.join(tempfile.mkdtemp(), + "mynewcert.pem") + + with open(cert_path, "w") as c: + c.write(self.KNOWN_GOOD_CERT) + + self._prepare_provider_config_with(cert_path, self.KNOWN_BAD_HASH) + + self.pb._should_proceed_cert = mock.MagicMock(return_value=True) + + with self.assertRaises(AssertionError): + self.pb._check_ca_fingerprint() + + os.unlink(cert_path) + + +############################################################################### +# Tests with a fake provider # +############################################################################### + +class ProviderBootstrapperActiveTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + factory = fake_provider.get_provider_factory() + http = reactor.listenTCP(8002, factory) + https = reactor.listenSSL( + 0, factory, + fake_provider.OpenSSLServerContextFactory()) + get_port = lambda p: p.getHost().port + cls.http_port = get_port(http) + cls.https_port = get_port(https) + + def setUp(self): + self.pb = ProviderBootstrapper() + + # At certain points we are going to be replacing these methods + # directly in ProviderConfig to be able to catch calls from + # new ProviderConfig objects inside the methods tested. We + # need to save the old implementation and restore it in + # tearDown so we are sure everything is as expected for each + # test. If we do it inside each specific test, a failure in + # the test will leave the implementation with the mock. + self.old_gpp = ProviderConfig.get_path_prefix + self.old_load = ProviderConfig.load + self.old_save = ProviderConfig.save + self.old_api_version = ProviderConfig.get_api_version + + def tearDown(self): + ProviderConfig.get_path_prefix = self.old_gpp + ProviderConfig.load = self.old_load + ProviderConfig.save = self.old_save + ProviderConfig.get_api_version = self.old_api_version + + def test_check_https_succeeds(self): + # XXX: Need a proper CA signed cert to test this + pass + + @deferred() + def test_check_https_fails(self): + self.pb._domain = "localhost:%s" % (self.https_port,) + + def check(*args): + with self.assertRaises(requests.exceptions.SSLError): + self.pb._check_https() + return threads.deferToThread(check) + + @deferred() + def test_second_check_https_fails(self): + self.pb._domain = "localhost:1234" + + def check(*args): + with self.assertRaises(Exception): + self.pb._check_https() + return threads.deferToThread(check) + + @deferred() + def test_check_https_succeeds_if_danger(self): + self.pb._domain = "localhost:%s" % (self.https_port,) + self.pb._bypass_checks = True + + def check(*args): + self.pb._check_https() + + return threads.deferToThread(check) + + def _setup_provider_config_with(self, api, path_prefix): + """ + Sets up the ProviderConfig with mocks for the path prefix, the + api returned and load/save methods. + It modifies ProviderConfig directly instead of an object + because the object used is created in the method itself and we + cannot control that. + + :param api: API to return + :type api: str + :param path_prefix: path prefix to be used when calculating + paths + :type path_prefix: str + """ + ProviderConfig.get_path_prefix = mock.MagicMock( + return_value=path_prefix) + ProviderConfig.get_api_version = mock.MagicMock( + return_value=api) + ProviderConfig.load = mock.MagicMock() + ProviderConfig.save = mock.MagicMock() + + def _setup_providerbootstrapper(self, ifneeded): + """ + Sets the provider bootstrapper's domain to + localhost:https_port, sets it to bypass https checks and sets + the download if needed based on the ifneeded value. + + :param ifneeded: Value for _download_if_needed + :type ifneeded: bool + """ + self.pb._domain = "localhost:%s" % (self.https_port,) + self.pb._bypass_checks = True + self.pb._download_if_needed = ifneeded + + def _produce_dummy_provider_json(self): + """ + Creates a dummy provider json on disk in order to test + behaviour around it (download if newer online, etc) + + :returns: the provider.json path used + :rtype: str + """ + provider_dir = os.path.join(ProviderConfig() + .get_path_prefix(), + "leap", + "providers", + self.pb._domain) + mkdir_p(provider_dir) + provider_path = os.path.join(provider_dir, + "provider.json") + + with open(provider_path, "w") as p: + p.write("A") + return provider_path + + def test_download_provider_info_not_modified(self): + self._setup_provider_config_with("1", tempfile.mkdtemp()) + self._setup_providerbootstrapper(True) + provider_path = self._produce_dummy_provider_json() + + # set mtime to something really new + os.utime(provider_path, (-1, time.time())) + + self.pb._download_provider_info() + # we check that it doesn't do anything with the provider + # config, because it's new enough + self.assertFalse(ProviderConfig.load.called) + self.assertFalse(ProviderConfig.save.called) + + def test_download_provider_info_modified(self): + self._setup_provider_config_with("1", tempfile.mkdtemp()) + self._setup_providerbootstrapper(True) + provider_path = self._produce_dummy_provider_json() + + # set mtime to something really old + os.utime(provider_path, (-1, 100)) + + self.pb._download_provider_info() + self.assertTrue(ProviderConfig.load.called) + self.assertTrue(ProviderConfig.save.called) + + def test_download_provider_info_unsupported_api_raises(self): + self._setup_provider_config_with("9999999", tempfile.mkdtemp()) + self._setup_providerbootstrapper(False) + self._produce_dummy_provider_json() + + with self.assertRaises(UnsupportedProviderAPI): + self.pb._download_provider_info() + + def test_download_provider_info_unsupported_api(self): + self._setup_provider_config_with(SupportedAPIs.SUPPORTED_APIS[0], + tempfile.mkdtemp()) + self._setup_providerbootstrapper(False) + self._produce_dummy_provider_json() + + self.pb._download_provider_info() + + def test_check_api_certificate_skips(self): + self.pb._provider_config = ProviderConfig() + self.pb._provider_config.get_api_uri = mock.MagicMock( + return_value="api.uri") + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value="/cert/path") + self.pb._session.get = mock.MagicMock(return_value=Response()) + + self.pb._should_proceed_cert = mock.MagicMock(return_value=False) + self.pb._check_api_certificate() + self.assertFalse(self.pb._session.get.called) + + @deferred() + def test_check_api_certificate_fails(self): + self.pb._provider_config = ProviderConfig() + self.pb._provider_config.get_api_uri = mock.MagicMock( + return_value="https://localhost:%s" % (self.https_port,)) + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value=os.path.join( + os.path.split(__file__)[0], + "wrongcert.pem")) + self.pb._provider_config.get_api_version = mock.MagicMock( + return_value="1") + + self.pb._should_proceed_cert = mock.MagicMock(return_value=True) + + def check(*args): + with self.assertRaises(requests.exceptions.SSLError): + self.pb._check_api_certificate() + d = threads.deferToThread(check) + return d + + @deferred() + def test_check_api_certificate_succeeds(self): + self.pb._provider_config = ProviderConfig() + self.pb._provider_config.get_api_uri = mock.MagicMock( + return_value="https://localhost:%s" % (self.https_port,)) + self.pb._provider_config.get_ca_cert_path = mock.MagicMock( + return_value=where('cacert.pem')) + self.pb._provider_config.get_api_version = mock.MagicMock( + return_value="1") + + self.pb._should_proceed_cert = mock.MagicMock(return_value=True) + + def check(*args): + self.pb._check_api_certificate() + d = threads.deferToThread(check) + return d diff --git a/src/leap/services/eip/tests/test_vpngatewayselector.py b/src/leap/services/eip/tests/test_vpngatewayselector.py new file mode 100644 index 00000000..c90681d7 --- /dev/null +++ b/src/leap/services/eip/tests/test_vpngatewayselector.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# test_vpngatewayselector.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +tests for vpngatewayselector +""" + +import unittest + +from leap.services.eip.eipconfig import EIPConfig, VPNGatewaySelector +from leap.common.testing.basetest import BaseLeapTest +from mock import Mock + + +sample_gateways = [ + {u'host': u'gateway1.com', + u'ip_address': u'1.2.3.4', + u'location': u'location1'}, + {u'host': u'gateway2.com', + u'ip_address': u'2.3.4.5', + u'location': u'location2'}, + {u'host': u'gateway3.com', + u'ip_address': u'3.4.5.6', + u'location': u'location3'}, + {u'host': u'gateway4.com', + u'ip_address': u'4.5.6.7', + u'location': u'location4'} +] + +sample_gateways_no_location = [ + {u'host': u'gateway1.com', + u'ip_address': u'1.2.3.4'}, + {u'host': u'gateway2.com', + u'ip_address': u'2.3.4.5'}, + {u'host': u'gateway3.com', + u'ip_address': u'3.4.5.6'} +] + +sample_locations = { + u'location1': {u'timezone': u'2'}, + u'location2': {u'timezone': u'-7'}, + u'location3': {u'timezone': u'-4'}, + u'location4': {u'timezone': u'+13'} +} + +# 0 is not used, only for indexing from 1 in tests +ips = (0, u'1.2.3.4', u'2.3.4.5', u'3.4.5.6', u'4.5.6.7') + + +class VPNGatewaySelectorTest(BaseLeapTest): + """ + VPNGatewaySelector's tests. + """ + def setUp(self): + self.eipconfig = EIPConfig() + self.eipconfig.get_gateways = Mock(return_value=sample_gateways) + self.eipconfig.get_locations = Mock(return_value=sample_locations) + + def tearDown(self): + pass + + def test_get_no_gateways(self): + gateway_selector = VPNGatewaySelector(self.eipconfig) + self.eipconfig.get_gateways = Mock(return_value=[]) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, []) + + def test_get_gateway_with_no_locations(self): + gateway_selector = VPNGatewaySelector(self.eipconfig) + self.eipconfig.get_gateways = Mock( + return_value=sample_gateways_no_location) + self.eipconfig.get_locations = Mock(return_value=[]) + gateways = gateway_selector.get_gateways() + gateways_default_order = [ + sample_gateways[0]['ip_address'], + sample_gateways[1]['ip_address'], + sample_gateways[2]['ip_address'] + ] + self.assertEqual(gateways, gateways_default_order) + + def test_correct_order_gmt(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, 0) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[1], ips[3], ips[2], ips[4]]) + + def test_correct_order_gmt_minus_3(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, -3) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[3], ips[2], ips[1], ips[4]]) + + def test_correct_order_gmt_minus_7(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, -7) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[2], ips[3], ips[4], ips[1]]) + + def test_correct_order_gmt_plus_5(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, 5) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[1], ips[4], ips[3], ips[2]]) + + def test_correct_order_gmt_plus_12(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, 12) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[4], ips[2], ips[3], ips[1]]) + + def test_correct_order_gmt_minus_11(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, -11) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[4], ips[2], ips[3], ips[1]]) + + def test_correct_order_gmt_plus_14(self): + gateway_selector = VPNGatewaySelector(self.eipconfig, 14) + gateways = gateway_selector.get_gateways() + self.assertEqual(gateways, [ips[4], ips[2], ips[3], ips[1]]) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/services/eip/tests/wrongcert.pem b/src/leap/services/eip/tests/wrongcert.pem new file mode 100644 index 00000000..e6cff38a --- /dev/null +++ b/src/leap/services/eip/tests/wrongcert.pem @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFtTCCA52gAwIBAgIJAIWZus5EIXNtMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTMwNjI1MTc0NjExWhcNMTgwNjI1MTc0NjExWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC +CgKCAgEA2ObM7ESjyuxFZYD/Y68qOPQgjgggW+cdXfBpU2p4n7clsrUeMhWdW40Y +77Phzor9VOeqs3ZpHuyLzsYVp/kFDm8tKyo2ah5fJwzL0VCSLYaZkUQQ7GNUmTCk +furaxl8cQx/fg395V7/EngsS9B3/y5iHbctbA4MnH3jaotO5EGeo6hw7/eyCotQ9 +KbBV9GJMcY94FsXBCmUB+XypKklWTLhSaS6Cu4Fo8YLW6WmcnsyEOGS2F7WVf5at +7CBWFQZHaSgIBLmc818/mDYCnYmCVMFn/6Ndx7V2NTlz+HctWrQn0dmIOnCUeCwS +wXq9PnBR1rSx/WxwyF/WpyjOFkcIo7vm72kS70pfrYsXcZD4BQqkXYj3FyKnPt3O +ibLKtCxL8/83wOtErPcYpG6LgFkgAAlHQ9MkUi5dbmjCJtpqQmlZeK1RALdDPiB3 +K1KZimrGsmcE624dJxUIOJJpuwJDy21F8kh5ZAsAtE1prWETrQYNElNFjQxM83rS +ZR1Ql2MPSB4usEZT57+KvpEzlOnAT3elgCg21XrjSFGi14hCEao4g2OEZH5GAwm5 +frf6UlSRZ/g3tLTfI8Hv1prw15W2qO+7q7SBAplTODCRk+Yb0YoA2mMM/QXBUcXs +vKEDLSSxzNIBi3T62l39RB/ml+gPKo87ZMDivex1ZhrcJc3Yu3sCAwEAAaOBpzCB +pDAdBgNVHQ4EFgQUPjE+4pun+8FreIdpoR8v6N7xKtUwdQYDVR0jBG4wbIAUPjE+ +4pun+8FreIdpoR8v6N7xKtWhSaRHMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpT +b21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGSCCQCF +mbrORCFzbTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4ICAQCpvCPdtvXJ +muTj379TZuCJs7/l0FhA7AHa1WAlHjsXHaA7N0+3ZWAbdtXDsowal6S+ldgU/kfV +Lq7NrRq+amJWC7SYj6cvVwhrSwSvu01fe/TWuOzHrRv1uTfJ/VXLonVufMDd9opo +bhqYxMaxLdIx6t/MYmZH4Wpiq0yfZuv//M8i7BBl/qvaWbLhg0yVAKRwjFvf59h6 +6tRFCLddELOIhLDQtk8zMbioPEbfAlKdwwP8kYGtDGj6/9/YTd/oTKRdgHuwyup3 +m0L20Y6LddC+tb0WpK5EyrNbCbEqj1L4/U7r6f/FKNA3bx6nfdXbscaMfYonKAKg +1cRrRg45sErmCz0QyTnWzXyvbjR4oQRzyW3kJ1JZudZ+AwOi00J5FYa3NiLuxl1u +gIGKWSrASQWhEdpa1nlCgX7PhdaQgYjEMpQvA0GCA0OF5JDu8en1yZqsOt1hCLIN +lkz/5jKPqrclY5hV99bE3hgCHRmIPNHCZG3wbZv2yJKxJX1YLMmQwAmSh2N7YwGG +yXRvCxQs5ChPHyRairuf/5MZCZnSVb45ppTVuNUijsbflKRUgfj/XvfqQ22f+C9N +Om2dmNvAiS2TOIfuP47CF2OUa5q4plUwmr+nyXQGM0SIoHNCj+MBdFfb3oxxAtI+ +SLhbnzQv5e84Doqz3YF0XW8jyR7q8GFLNA== +-----END CERTIFICATE----- diff --git a/src/leap/services/eip/udstelnet.py b/src/leap/services/eip/udstelnet.py new file mode 100644 index 00000000..e6c82350 --- /dev/null +++ b/src/leap/services/eip/udstelnet.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# udstelnet.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import os +import socket +import telnetlib + + +class ConnectionRefusedError(Exception): + pass + + +class MissingSocketError(Exception): + pass + + +class UDSTelnet(telnetlib.Telnet): + """ + A telnet-alike class, that can listen on unix domain sockets + """ + + def open(self, host, port=23, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): + """ + Connect to a host. If port is 'unix', it will open a + connection over unix docmain sockets. + + The optional second argument is the port number, which + defaults to the standard telnet port (23). + Don't try to reopen an already connected instance. + """ + self.eof = 0 + self.host = host + self.port = port + self.timeout = timeout + + if self.port == "unix": + # unix sockets spoken + if not os.path.exists(self.host): + raise MissingSocketError() + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + self.sock.connect(self.host) + except socket.error: + raise ConnectionRefusedError() + else: + self.sock = socket.create_connection((host, port), timeout) diff --git a/src/leap/services/eip/vpnlaunchers.py b/src/leap/services/eip/vpnlaunchers.py new file mode 100644 index 00000000..570a7893 --- /dev/null +++ b/src/leap/services/eip/vpnlaunchers.py @@ -0,0 +1,774 @@ +# -*- coding: utf-8 -*- +# vpnlaunchers.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +""" +Platform dependant VPN launchers +""" +import commands +import logging +import getpass +import os +import platform +import subprocess +import stat +try: + import grp +except ImportError: + pass # ignore, probably windows + +from abc import ABCMeta, abstractmethod +from functools import partial + +from leap.common.check import leap_assert, leap_assert_type +from leap.common.files import which +from leap.config.providerconfig import ProviderConfig +from leap.services.eip.eipconfig import EIPConfig, VPNGatewaySelector +from leap.util import first + +logger = logging.getLogger(__name__) + + +class VPNLauncherException(Exception): + pass + + +class OpenVPNNotFoundException(VPNLauncherException): + pass + + +class EIPNoPolkitAuthAgentAvailable(VPNLauncherException): + pass + + +class EIPNoPkexecAvailable(VPNLauncherException): + pass + + +class VPNLauncher: + """ + Abstract launcher class + """ + __metaclass__ = ABCMeta + + UPDOWN_FILES = None + OTHER_FILES = None + + @abstractmethod + def get_vpn_command(self, eipconfig=None, providerconfig=None, + socket_host=None, socket_port=None): + """ + Returns the platform dependant vpn launching command + + :param eipconfig: eip configuration object + :type eipconfig: EIPConfig + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + :param socket_host: either socket path (unix) or socket IP + :type socket_host: str + :param socket_port: either string "unix" if it's a unix + socket, or port otherwise + :type socket_port: str + + :return: A VPN command ready to be launched + :rtype: list + """ + return [] + + @abstractmethod + def get_vpn_env(self, providerconfig): + """ + Returns a dictionary with the custom env for the platform. + This is mainly used for setting LD_LIBRARY_PATH to the correct + path when distributing a standalone client + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :rtype: dict + """ + return {} + + @classmethod + def missing_updown_scripts(kls): + """ + Returns what updown scripts are missing. + :rtype: list + """ + leap_assert(kls.UPDOWN_FILES is not None, + "Need to define UPDOWN_FILES for this particular " + "auncher before calling this method") + file_exist = partial(_has_updown_scripts, warn=False) + zipped = zip(kls.UPDOWN_FILES, map(file_exist, kls.UPDOWN_FILES)) + missing = filter(lambda (path, exists): exists is False, zipped) + return [path for path, exists in missing] + + @classmethod + def missing_other_files(kls): + """ + Returns what other important files are missing during startup. + Same as missing_updown_scripts but does not check for exec bit. + :rtype: list + """ + leap_assert(kls.UPDOWN_FILES is not None, + "Need to define OTHER_FILES for this particular " + "auncher before calling this method") + file_exist = partial(_has_other_files, warn=False) + zipped = zip(kls.OTHER_FILES, map(file_exist, kls.OTHER_FILES)) + missing = filter(lambda (path, exists): exists is False, zipped) + return [path for path, exists in missing] + + +def get_platform_launcher(): + launcher = globals()[platform.system() + "VPNLauncher"] + leap_assert(launcher, "Unimplemented platform launcher: %s" % + (platform.system(),)) + return launcher() + + +def _is_pkexec_in_system(): + """ + Checks the existence of the pkexec binary in system. + """ + pkexec_path = which('pkexec') + if len(pkexec_path) == 0: + return False + return True + + +def _has_updown_scripts(path, warn=True): + """ + Checks the existence of the up/down scripts and its + exec bit if applicable. + + :param path: the path to be checked + :type path: str + + :param warn: whether we should log the absence + :type warn: bool + + :rtype: bool + """ + is_file = os.path.isfile(path) + if warn and not is_file: + logger.error("Could not find up/down script %s. " + "Might produce DNS leaks." % (path,)) + + # XXX check if applies in win + is_exe = False + try: + is_exe = (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] != 0) + except OSError as e: + logger.warn("%s" % (e,)) + if warn and not is_exe: + logger.error("Up/down script %s is not executable. " + "Might produce DNS leaks." % (path,)) + return is_file and is_exe + + +def _has_other_files(path, warn=True): + """ + Checks the existence of other important files. + + :param path: the path to be checked + :type path: str + + :param warn: whether we should log the absence + :type warn: bool + + :rtype: bool + """ + is_file = os.path.isfile(path) + if warn and not is_file: + logger.warning("Could not find file during checks: %s. " % ( + path,)) + return is_file + + +def _is_auth_agent_running(): + """ + Checks if a polkit daemon is running. + + :return: True if it's running, False if it's not. + :rtype: boolean + """ + ps = 'ps aux | grep polkit-%s-authentication-agent-1' + opts = (ps % case for case in ['[g]nome', '[k]de']) + is_running = map(lambda l: commands.getoutput(l), opts) + return any(is_running) + + +def _try_to_launch_agent(): + """ + Tries to launch a polkit daemon. + """ + opts = [ + "/usr/lib/policykit-1-gnome/polkit-gnome-authentication-agent-1", + # XXX add kde thing here + ] + for cmd in opts: + try: + subprocess.Popen([cmd], shell=True) + except: + pass + + +class LinuxVPNLauncher(VPNLauncher): + """ + VPN launcher for the Linux platform + """ + + PKEXEC_BIN = 'pkexec' + OPENVPN_BIN = 'openvpn' + SYSTEM_CONFIG = "/etc/leap" + UP_DOWN_FILE = "resolv-update" + UP_DOWN_PATH = "%s/%s" % (SYSTEM_CONFIG, UP_DOWN_FILE) + + # We assume this is there by our openvpn dependency, and + # we will put it there on the bundle too. + # TODO adapt to the bundle path. + OPENVPN_DOWN_ROOT_BASE = "/usr/lib/openvpn/" + OPENVPN_DOWN_ROOT_FILE = "openvpn-plugin-down-root.so" + OPENVPN_DOWN_ROOT_PATH = "%s/%s" % ( + OPENVPN_DOWN_ROOT_BASE, + OPENVPN_DOWN_ROOT_FILE) + + POLKIT_BASE = "/usr/share/polkit-1/actions" + POLKIT_FILE = "net.openvpn.gui.leap.policy" + POLKIT_PATH = "%s/%s" % (POLKIT_BASE, POLKIT_FILE) + + UPDOWN_FILES = (UP_DOWN_PATH,) + OTHER_FILES = (POLKIT_PATH,) + + @classmethod + def cmd_for_missing_scripts(kls, frompath): + """ + Returns a command that can copy the missing scripts. + :rtype: str + """ + to = kls.SYSTEM_CONFIG + cmd = "#!/bin/sh\nset -e\nmkdir -p %s\ncp %s/%s %s\ncp %s/%s %s" % ( + to, + frompath, kls.UP_DOWN_FILE, to, + frompath, kls.POLKIT_FILE, kls.POLKIT_PATH) + return cmd + + @classmethod + def maybe_pkexec(kls): + """ + Checks whether pkexec is available in the system, and + returns the path if found. + + Might raise EIPNoPkexecAvailable or EIPNoPolkitAuthAgentAvailable + + :returns: a list of the paths where pkexec is to be found + :rtype: list + """ + if _is_pkexec_in_system(): + if not _is_auth_agent_running(): + _try_to_launch_agent() + if _is_auth_agent_running(): + pkexec_possibilities = which(kls.PKEXEC_BIN) + leap_assert(len(pkexec_possibilities) > 0, + "We couldn't find pkexec") + return pkexec_possibilities + else: + logger.warning("No polkit auth agent found. pkexec " + + "will use its own auth agent.") + raise EIPNoPolkitAuthAgentAvailable() + else: + logger.warning("System has no pkexec") + raise EIPNoPkexecAvailable() + + @classmethod + def maybe_down_plugin(kls): + """ + Returns the path of the openvpn down-root-plugin, searching first + in the relative path for the standalone bundle, and then in the system + path where the debian package puts it. + + :returns: the path where the plugin was found, or None + :rtype: str or None + """ + cwd = os.getcwd() + rel_path_in_bundle = os.path.join( + 'apps', 'eip', 'files', kls.OPENVPN_DOWN_ROOT_FILE) + abs_path_in_bundle = os.path.join(cwd, rel_path_in_bundle) + if os.path.isfile(abs_path_in_bundle): + return abs_path_in_bundle + abs_path_in_system = kls.OPENVPN_DOWN_ROOT_FILE + if os.path.isfile(abs_path_in_system): + return abs_path_in_system + + logger.warning("We could not find the down-root-plugin, so no updown " + "scripts will be run. DNS leaks are likely!") + return None + + def get_vpn_command(self, eipconfig=None, providerconfig=None, + socket_host=None, socket_port="unix"): + """ + Returns the platform dependant vpn launching command. It will + look for openvpn in the regular paths and algo in + path_prefix/apps/eip/ (in case standalone is set) + + Might raise VPNException. + + :param eipconfig: eip configuration object + :type eipconfig: EIPConfig + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :param socket_host: either socket path (unix) or socket IP + :type socket_host: str + + :param socket_port: either string "unix" if it's a unix + socket, or port otherwise + :type socket_port: str + + :return: A VPN command ready to be launched + :rtype: list + """ + leap_assert(eipconfig, "We need an eip config") + leap_assert_type(eipconfig, EIPConfig) + leap_assert(providerconfig, "We need a provider config") + leap_assert_type(providerconfig, ProviderConfig) + leap_assert(socket_host, "We need a socket host!") + leap_assert(socket_port, "We need a socket port!") + + kwargs = {} + if ProviderConfig.standalone: + kwargs['path_extension'] = os.path.join( + providerconfig.get_path_prefix(), + "..", "apps", "eip") + + openvpn_possibilities = which(self.OPENVPN_BIN, **kwargs) + + if len(openvpn_possibilities) == 0: + raise OpenVPNNotFoundException() + + openvpn = first(openvpn_possibilities) + args = [] + + pkexec = self.maybe_pkexec() + if pkexec: + args.append(openvpn) + openvpn = first(pkexec) + + # TODO: handle verbosity + + gateway_selector = VPNGatewaySelector(eipconfig) + gateways = gateway_selector.get_gateways() + + logger.debug("Using gateways ips: {}".format(', '.join(gateways))) + + for gw in gateways: + args += ['--remote', gw, '1194', 'udp'] + + args += [ + '--client', + '--dev', 'tun', + '--persist-tun', + '--persist-key', + '--tls-client', + '--remote-cert-tls', + 'server' + ] + + openvpn_configuration = eipconfig.get_openvpn_configuration() + + for key, value in openvpn_configuration.items(): + args += ['--%s' % (key,), value] + + args += [ + '--user', getpass.getuser(), + '--group', grp.getgrgid(os.getgroups()[-1]).gr_name + ] + + if socket_port == "unix": # that's always the case for linux + args += [ + '--management-client-user', getpass.getuser() + ] + + args += [ + '--management-signal', + '--management', socket_host, socket_port, + '--script-security', '2' + ] + + plugin_path = self.maybe_down_plugin() + # If we do not have the down plugin neither in the bundle + # nor in the system, we do not do updown scripts. The alternative + # is leaving the user without the ability to restore dns and routes + # to its original state. + + if plugin_path and _has_updown_scripts(self.UP_DOWN_PATH): + args += [ + '--up', self.UP_DOWN_PATH, + '--down', self.UP_DOWN_PATH, + '--plugin', plugin_path, + '\'script_type=down %s\'' % self.UP_DOWN_PATH + ] + + args += [ + '--cert', eipconfig.get_client_cert_path(providerconfig), + '--key', eipconfig.get_client_cert_path(providerconfig), + '--ca', providerconfig.get_ca_cert_path() + ] + + logger.debug("Running VPN with command:") + logger.debug("%s %s" % (openvpn, " ".join(args))) + + return [openvpn] + args + + def get_vpn_env(self, providerconfig): + """ + Returns a dictionary with the custom env for the platform. + This is mainly used for setting LD_LIBRARY_PATH to the correct + path when distributing a standalone client + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :rtype: dict + """ + leap_assert(providerconfig, "We need a provider config") + leap_assert_type(providerconfig, ProviderConfig) + + return {"LD_LIBRARY_PATH": os.path.join( + providerconfig.get_path_prefix(), + "..", "lib")} + + +class DarwinVPNLauncher(VPNLauncher): + """ + VPN launcher for the Darwin Platform + """ + + COCOASUDO = "cocoasudo" + # XXX need magic translate for this string + SUDO_MSG = ("LEAP needs administrative privileges to run " + "Encrypted Internet.") + + INSTALL_PATH = "/Applications/LEAP\ Client.app" + OPENVPN_BIN = 'openvpn.leap' + OPENVPN_PATH = "%s/Contents/Resources/openvpn" % (INSTALL_PATH,) + + UP_SCRIPT = "%s/client.up.sh" % (OPENVPN_PATH,) + DOWN_SCRIPT = "%s/client.down.sh" % (OPENVPN_PATH,) + OPENVPN_DOWN_PLUGIN = '%s/openvpn-down-root.so' % (OPENVPN_PATH,) + + UPDOWN_FILES = (UP_SCRIPT, DOWN_SCRIPT, OPENVPN_DOWN_PLUGIN) + + @classmethod + def cmd_for_missing_scripts(kls, frompath): + """ + Returns a command that can copy the missing scripts. + :rtype: str + """ + to = kls.OPENVPN_PATH + cmd = "#!/bin/sh\nmkdir -p %s\ncp \"%s/\"* %s" % (to, frompath, to) + return cmd + + def get_cocoasudo_cmd(self): + """ + Returns a string with the cocoasudo command needed to run openvpn + as admin with a nice password prompt. The actual command needs to be + appended. + + :rtype: (str, list) + """ + iconpath = os.path.abspath(os.path.join( + os.getcwd(), + "../../../Resources/leap-client.tiff")) + has_icon = os.path.isfile(iconpath) + args = ["--icon=%s" % iconpath] if has_icon else [] + args.append("--prompt=%s" % (self.SUDO_MSG,)) + + return self.COCOASUDO, args + + def get_vpn_command(self, eipconfig=None, providerconfig=None, + socket_host=None, socket_port="unix"): + """ + Returns the platform dependant vpn launching command + + Might raise VPNException. + + :param eipconfig: eip configuration object + :type eipconfig: EIPConfig + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :param socket_host: either socket path (unix) or socket IP + :type socket_host: str + + :param socket_port: either string "unix" if it's a unix + socket, or port otherwise + :type socket_port: str + + :return: A VPN command ready to be launched + :rtype: list + """ + leap_assert(eipconfig, "We need an eip config") + leap_assert_type(eipconfig, EIPConfig) + leap_assert(providerconfig, "We need a provider config") + leap_assert_type(providerconfig, ProviderConfig) + leap_assert(socket_host, "We need a socket host!") + leap_assert(socket_port, "We need a socket port!") + + kwargs = {} + if ProviderConfig.standalone: + kwargs['path_extension'] = os.path.join( + providerconfig.get_path_prefix(), + "..", "apps", "eip") + + openvpn_possibilities = which( + self.OPENVPN_BIN, + **kwargs) + if len(openvpn_possibilities) == 0: + raise OpenVPNNotFoundException() + + openvpn = first(openvpn_possibilities) + args = [openvpn] + + # TODO: handle verbosity + + gateway_selector = VPNGatewaySelector(eipconfig) + gateways = gateway_selector.get_gateways() + + logger.debug("Using gateways ips: {gw}".format( + gw=', '.join(gateways))) + + for gw in gateways: + args += ['--remote', gw, '1194', 'udp'] + + args += [ + '--client', + '--dev', 'tun', + '--persist-tun', + '--persist-key', + '--tls-client', + '--remote-cert-tls', + 'server' + ] + + openvpn_configuration = eipconfig.get_openvpn_configuration() + for key, value in openvpn_configuration.items(): + args += ['--%s' % (key,), value] + + user = getpass.getuser() + args += [ + '--user', user, + '--group', grp.getgrgid(os.getgroups()[-1]).gr_name + ] + + if socket_port == "unix": + args += [ + '--management-client-user', user + ] + + args += [ + '--management-signal', + '--management', socket_host, socket_port, + '--script-security', '2' + ] + + if _has_updown_scripts(self.UP_SCRIPT): + args += [ + '--up', self.UP_SCRIPT, + ] + + if _has_updown_scripts(self.DOWN_SCRIPT): + args += [ + '--down', self.DOWN_SCRIPT] + + # should have the down script too + if _has_updown_scripts(self.OPENVPN_DOWN_PLUGIN): + args += [ + '--plugin', self.OPENVPN_DOWN_PLUGIN, + '\'%s\'' % self.DOWN_SCRIPT + ] + + # we set user to be passed to the up/down scripts + args += [ + '--setenv', "LEAPUSER", "%s" % (user,)] + + args += [ + '--cert', eipconfig.get_client_cert_path(providerconfig), + '--key', eipconfig.get_client_cert_path(providerconfig), + '--ca', providerconfig.get_ca_cert_path() + ] + + command, cargs = self.get_cocoasudo_cmd() + cmd_args = cargs + args + + logger.debug("Running VPN with command:") + logger.debug("%s %s" % (command, " ".join(cmd_args))) + + return [command] + cmd_args + + def get_vpn_env(self, providerconfig): + """ + Returns a dictionary with the custom env for the platform. + This is mainly used for setting LD_LIBRARY_PATH to the correct + path when distributing a standalone client + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :rtype: dict + """ + return {"DYLD_LIBRARY_PATH": os.path.join( + providerconfig.get_path_prefix(), + "..", "lib")} + + +class WindowsVPNLauncher(VPNLauncher): + """ + VPN launcher for the Windows platform + """ + + OPENVPN_BIN = 'openvpn_leap.exe' + + # XXX UPDOWN_FILES ... we do not have updown files defined yet! + + def get_vpn_command(self, eipconfig=None, providerconfig=None, + socket_host=None, socket_port="9876"): + """ + Returns the platform dependant vpn launching command. It will + look for openvpn in the regular paths and algo in + path_prefix/apps/eip/ (in case standalone is set) + + Might raise VPNException. + + :param eipconfig: eip configuration object + :type eipconfig: EIPConfig + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + :param socket_host: either socket path (unix) or socket IP + :type socket_host: str + :param socket_port: either string "unix" if it's a unix + socket, or port otherwise + :type socket_port: str + + :return: A VPN command ready to be launched + :rtype: list + """ + leap_assert(eipconfig, "We need an eip config") + leap_assert_type(eipconfig, EIPConfig) + leap_assert(providerconfig, "We need a provider config") + leap_assert_type(providerconfig, ProviderConfig) + leap_assert(socket_host, "We need a socket host!") + leap_assert(socket_port, "We need a socket port!") + leap_assert(socket_port != "unix", + "We cannot use unix sockets in windows!") + + openvpn_possibilities = which( + self.OPENVPN_BIN, + path_extension=os.path.join(providerconfig.get_path_prefix(), + "..", "apps", "eip")) + + if len(openvpn_possibilities) == 0: + raise OpenVPNNotFoundException() + + openvpn = first(openvpn_possibilities) + args = [] + + # TODO: handle verbosity + + gateway_selector = VPNGatewaySelector(eipconfig) + gateways = gateway_selector.get_gateways() + + logger.debug("Using gateways ips: {}".format(', '.join(gateways))) + + for gw in gateways: + args += ['--remote', gw, '1194', 'udp'] + + args += [ + '--client', + '--dev', 'tun', + '--persist-tun', + '--persist-key', + '--tls-client', + '--remote-cert-tls', + 'server' + ] + + openvpn_configuration = eipconfig.get_openvpn_configuration() + for key, value in openvpn_configuration.items(): + args += ['--%s' % (key,), value] + + args += [ + '--user', getpass.getuser(), + #'--group', grp.getgrgid(os.getgroups()[-1]).gr_name + ] + args += [ + '--management-signal', + '--management', socket_host, socket_port, + '--script-security', '2' + ] + args += [ + '--cert', eipconfig.get_client_cert_path(providerconfig), + '--key', eipconfig.get_client_cert_path(providerconfig), + '--ca', providerconfig.get_ca_cert_path() + ] + + logger.debug("Running VPN with command:") + logger.debug("%s %s" % (openvpn, " ".join(args))) + + return [openvpn] + args + + def get_vpn_env(self, providerconfig): + """ + Returns a dictionary with the custom env for the platform. + This is mainly used for setting LD_LIBRARY_PATH to the correct + path when distributing a standalone client + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :rtype: dict + """ + return {} + + +if __name__ == "__main__": + logger = logging.getLogger(name='leap') + logger.setLevel(logging.DEBUG) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '%(asctime)s ' + '- %(name)s - %(levelname)s - %(message)s') + console.setFormatter(formatter) + logger.addHandler(console) + + try: + abs_launcher = VPNLauncher() + except Exception as e: + assert isinstance(e, TypeError), "Something went wrong" + print "Abstract Prefixer class is working as expected" + + vpnlauncher = get_platform_launcher() + + eipconfig = EIPConfig() + if eipconfig.load("leap/providers/bitmask.net/eip-service.json"): + provider = ProviderConfig() + if provider.load("leap/providers/bitmask.net/provider.json"): + vpnlauncher.get_vpn_command(eipconfig=eipconfig, + providerconfig=provider, + socket_host="/blah") diff --git a/src/leap/services/eip/vpnprocess.py b/src/leap/services/eip/vpnprocess.py new file mode 100644 index 00000000..0ec56ae7 --- /dev/null +++ b/src/leap/services/eip/vpnprocess.py @@ -0,0 +1,718 @@ +# -*- coding: utf-8 -*- +# vpnprocess.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +VPN Manager, spawned in a custom processProtocol. +""" +import logging +import os +import psutil +import psutil.error +import shutil +import socket + +from PySide import QtCore + +from leap.common.check import leap_assert, leap_assert_type +from leap.config.providerconfig import ProviderConfig +from leap.services.eip.vpnlaunchers import get_platform_launcher +from leap.services.eip.eipconfig import EIPConfig +from leap.services.eip.udstelnet import UDSTelnet + +logger = logging.getLogger(__name__) +vpnlog = logging.getLogger('leap.openvpn') + +from twisted.internet import protocol +from twisted.internet import defer +from twisted.internet.task import LoopingCall +from twisted.internet import error as internet_error + + +class VPNSignals(QtCore.QObject): + """ + These are the signals that we use to let the UI know + about the events we are polling. + They are instantiated in the VPN object and passed along + till the VPNProcess. + """ + state_changed = QtCore.Signal(dict) + status_changed = QtCore.Signal(dict) + process_finished = QtCore.Signal(int) + + def __init__(self): + QtCore.QObject.__init__(self) + + +class VPN(object): + """ + This is the high-level object that the GUI is dealing with. + It exposes the start and terminate methods. + + On start, it spawns a VPNProcess instance that will use a vpnlauncher + suited for the running platform and connect to the management interface + opened by the openvpn process, executing commands over that interface on + demand. + """ + TERMINATE_MAXTRIES = 10 + TERMINATE_WAIT = 1 # secs + + def __init__(self): + """ + Instantiate empty attributes and get a copy + of a QObject containing the QSignals that we will pass along + to the VPNManager. + """ + from twisted.internet import reactor + self._vpnproc = None + self._pollers = [] + self._reactor = reactor + self._qtsigs = VPNSignals() + + @property + def qtsigs(self): + return self._qtsigs + + def start(self, *args, **kwargs): + """ + Starts the openvpn subprocess. + + :param args: args to be passed to the VPNProcess + :type args: tuple + + :param kwargs: kwargs to be passed to the VPNProcess + :type kwargs: dict + """ + kwargs['qtsigs'] = self.qtsigs + + # start the main vpn subprocess + vpnproc = VPNProcess(*args, **kwargs) + + # XXX Should stop if already running ------- + if vpnproc.get_openvpn_process(): + logger.warning("Another vpnprocess is running!") + + cmd = vpnproc.getCommand() + env = os.environ + for key, val in vpnproc.vpn_env.items(): + env[key] = val + + self._reactor.spawnProcess(vpnproc, cmd[0], cmd, env) + self._vpnproc = vpnproc + + # add pollers for status and state + # this could be extended to a collection of + # generic watchers + + poll_list = [LoopingCall(vpnproc.pollStatus), + LoopingCall(vpnproc.pollState)] + self._pollers.extend(poll_list) + self._start_pollers() + + def _kill_if_left_alive(self, tries=0): + """ + Check if the process is still alive, and sends a + SIGKILL after a timeout period. + + :param tries: counter of tries, used in recursion + :type tries: int + """ + from twisted.internet import reactor + while tries < self.TERMINATE_MAXTRIES: + if self._vpnproc.transport.pid is None: + logger.debug("Process has been happily terminated.") + return + else: + logger.debug("Process did not die, waiting...") + tries += 1 + reactor.callLater(self.TERMINATE_WAIT, + self._kill_if_left_alive, tries) + + # after running out of patience, we try a killProcess + logger.debug("Process did not died. Sending a SIGKILL.") + self.killit() + + def killit(self): + """ + Sends a kill signal to the process. + """ + self._stop_pollers() + self._vpnproc.aborted = True + self._vpnproc.killProcess() + + def terminate(self, shutdown=False): + """ + Stops the openvpn subprocess. + + Attempts to send a SIGTERM first, and after a timeout + it sends a SIGKILL. + """ + from twisted.internet import reactor + self._stop_pollers() + + # First we try to be polite and send a SIGTERM... + if self._vpnproc: + self._sentterm = True + self._vpnproc.terminate_openvpn(shutdown=shutdown) + + # ...but we also trigger a countdown to be unpolite + # if strictly needed. + reactor.callLater( + self.TERMINATE_WAIT, self._kill_if_left_alive) + + def _start_pollers(self): + """ + Iterate through the registered observers + and start the looping call for them. + """ + for poller in self._pollers: + poller.start(VPNManager.POLL_TIME) + + def _stop_pollers(self): + """ + Iterate through the registered observers + and stop the looping calls if they are running. + """ + for poller in self._pollers: + if poller.running: + poller.stop() + self._pollers = [] + + +class VPNManager(object): + """ + This is a mixin that we use in the VPNProcess class. + Here we get together all methods related with the openvpn management + interface. + + A copy of a QObject containing signals as attributes is passed along + upon initialization, and we use that object to emit signals to qt-land. + + For more info about management methods:: + + zcat `dpkg -L openvpn | grep management` + """ + + # Timers, in secs + POLL_TIME = 0.5 + CONNECTION_RETRY_TIME = 1 + + TS_KEY = "ts" + STATUS_STEP_KEY = "status_step" + OK_KEY = "ok" + IP_KEY = "ip" + REMOTE_KEY = "remote" + + TUNTAP_READ_KEY = "tun_tap_read" + TUNTAP_WRITE_KEY = "tun_tap_write" + TCPUDP_READ_KEY = "tcp_udp_read" + TCPUDP_WRITE_KEY = "tcp_udp_write" + AUTH_READ_KEY = "auth_read" + + def __init__(self, qtsigs=None): + """ + Initializes the VPNManager. + + :param qtsigs: a QObject containing the Qt signals used by the UI + to give feedback about state changes. + :type qtsigs: QObject + """ + from twisted.internet import reactor + self._reactor = reactor + self._tn = None + self._qtsigs = qtsigs + self._aborted = False + + @property + def qtsigs(self): + return self._qtsigs + + @property + def aborted(self): + return self._aborted + + @aborted.setter + def aborted(self, value): + self._aborted = value + + def _seek_to_eof(self): + """ + Read as much as available. Position seek pointer to end of stream + """ + try: + self._tn.read_eager() + except EOFError: + logger.debug("Could not read from socket. Assuming it died.") + return + + def _send_command(self, command, until=b"END"): + """ + Sends a command to the telnet connection and reads until END + is reached. + + :param command: command to send + :type command: str + + :param until: byte delimiter string for reading command output + :type until: byte str + + :return: response read + :rtype: list + """ + leap_assert(self._tn, "We need a tn connection!") + + try: + self._tn.write("%s\n" % (command,)) + buf = self._tn.read_until(until, 2) + self._seek_to_eof() + blist = buf.split('\r\n') + if blist[-1].startswith(until): + del blist[-1] + return blist + else: + return [] + + except socket.error: + # XXX should get a counter and repeat only + # after mod X times. + logger.warning('socket error') + self._close_management_socket(announce=False) + return [] + + # XXX should move this to a errBack! + except Exception as e: + logger.warning("Error sending command %s: %r" % + (command, e)) + return [] + + def _close_management_socket(self, announce=True): + """ + Close connection to openvpn management interface. + """ + logger.debug('closing socket') + if announce: + self._tn.write("quit\n") + self._tn.read_all() + self._tn.get_socket().close() + self._tn = None + + def _connect_management(self, socket_host, socket_port): + """ + Connects to the management interface on the specified + socket_host socket_port. + + :param socket_host: either socket path (unix) or socket IP + :type socket_host: str + + :param socket_port: either string "unix" if it's a unix + socket, or port otherwise + :type socket_port: str + """ + if self.is_connected(): + self._close_management_socket() + + try: + self._tn = UDSTelnet(socket_host, socket_port) + + # XXX make password optional + # specially for win. we should generate + # the pass on the fly when invoking manager + # from conductor + + # self.tn.read_until('ENTER PASSWORD:', 2) + # self.tn.write(self.password + '\n') + # self.tn.read_until('SUCCESS:', 2) + if self._tn: + self._tn.read_eager() + + # XXX move this to the Errback + except Exception as e: + logger.warning("Could not connect to OpenVPN yet: %r" % (e,)) + self._tn = None + + def _connectCb(self, *args): + """ + Callback for connection. + + :param args: not used + """ + if self._tn: + logger.info('connected to management') + + def _connectErr(self, failure): + """ + Errorback for connection. + + :param failure: Failure + """ + logger.warning(failure) + + def connect_to_management(self, host, port): + """ + Connect to a management interface. + + :param host: the host of the management interface + :type host: str + + :param port: the port of the management interface + :type port: str + + :returns: a deferred + """ + self.connectd = defer.maybeDeferred( + self._connect_management, host, port) + self.connectd.addCallbacks(self._connectCb, self._connectErr) + return self.connectd + + def is_connected(self): + """ + Returns the status of the management interface. + + :returns: True if connected, False otherwise + :rtype: bool + """ + return True if self._tn else False + + def try_to_connect_to_management(self, retry=0): + """ + Attempts to connect to a management interface, and retries + after CONNECTION_RETRY_TIME if not successful. + + :param retry: number of the retry + :type retry: int + """ + # TODO decide about putting a max_lim to retries and signaling + # an error. + if not self.aborted and not self.is_connected(): + self.connect_to_management(self._socket_host, self._socket_port) + self._reactor.callLater( + self.CONNECTION_RETRY_TIME, + self.try_to_connect_to_management, retry + 1) + + def _parse_state_and_notify(self, output): + """ + Parses the output of the state command and emits state_changed + signal when the state changes. + + :param output: list of lines that the state command printed as + its output + :type output: list + """ + for line in output: + stripped = line.strip() + if stripped == "END": + continue + parts = stripped.split(",") + if len(parts) < 5: + continue + ts, status_step, ok, ip, remote = parts + + state_dict = { + self.TS_KEY: ts, + self.STATUS_STEP_KEY: status_step, + self.OK_KEY: ok, + self.IP_KEY: ip, + self.REMOTE_KEY: remote + } + + if state_dict != self._last_state: + self.qtsigs.state_changed.emit(state_dict) + self._last_state = state_dict + + def _parse_status_and_notify(self, output): + """ + Parses the output of the status command and emits + status_changed signal when the status changes. + + :param output: list of lines that the status command printed + as its output + :type output: list + """ + tun_tap_read = "" + tun_tap_write = "" + tcp_udp_read = "" + tcp_udp_write = "" + auth_read = "" + for line in output: + stripped = line.strip() + if stripped.endswith("STATISTICS") or stripped == "END": + continue + parts = stripped.split(",") + if len(parts) < 2: + continue + if parts[0].strip() == "TUN/TAP read bytes": + tun_tap_read = parts[1] + elif parts[0].strip() == "TUN/TAP write bytes": + tun_tap_write = parts[1] + elif parts[0].strip() == "TCP/UDP read bytes": + tcp_udp_read = parts[1] + elif parts[0].strip() == "TCP/UDP write bytes": + tcp_udp_write = parts[1] + elif parts[0].strip() == "Auth read bytes": + auth_read = parts[1] + + status_dict = { + self.TUNTAP_READ_KEY: tun_tap_read, + self.TUNTAP_WRITE_KEY: tun_tap_write, + self.TCPUDP_READ_KEY: tcp_udp_read, + self.TCPUDP_WRITE_KEY: tcp_udp_write, + self.AUTH_READ_KEY: auth_read + } + + if status_dict != self._last_status: + self.qtsigs.status_changed.emit(status_dict) + self._last_status = status_dict + + def get_state(self): + """ + Notifies the gui of the output of the state command over + the openvpn management interface. + """ + if self.is_connected(): + return self._parse_state_and_notify(self._send_command("state")) + + def get_status(self): + """ + Notifies the gui of the output of the status command over + the openvpn management interface. + """ + if self.is_connected(): + return self._parse_status_and_notify(self._send_command("status")) + + @property + def vpn_env(self): + """ + Return a dict containing the vpn environment to be used. + """ + return self._launcher.get_vpn_env(self._providerconfig) + + def terminate_openvpn(self, shutdown=False): + """ + Attempts to terminate openvpn by sending a SIGTERM. + """ + if self.is_connected(): + self._send_command("signal SIGTERM") + if shutdown: + self._cleanup_tempfiles() + + def _cleanup_tempfiles(self): + """ + Remove all temporal files we might have left behind. + + Iif self.port is 'unix', we have created a temporal socket path that, + under normal circumstances, we should be able to delete. + """ + if self._socket_port == "unix": + logger.debug('cleaning socket file temp folder') + tempfolder = os.path.split(self._socket_host)[0] # XXX use `first` + if os.path.isdir(tempfolder): + try: + shutil.rmtree(tempfolder) + except OSError: + logger.error('could not delete tmpfolder %s' % tempfolder) + + # --------------------------------------------------- + # XXX old methods, not adapted to twisted process yet + + def get_openvpn_process(self): + """ + Looks for openvpn instances running. + + :rtype: process + """ + openvpn_process = None + for p in psutil.process_iter(): + try: + # XXX Not exact! + # Will give false positives. + # we should check that cmdline BEGINS + # with openvpn or with our wrapper + # (pkexec / osascript / whatever) + if "openvpn" in ' '.join(p.cmdline): + openvpn_process = p + break + except psutil.error.AccessDenied: + pass + return openvpn_process + + def _stop_if_already_running(self): + """ + Checks if VPN is already running and tries to stop it. + + :return: True if stopped, False otherwise + """ + # TODO cleanup this + process = self._get_openvpn_process() + if process: + logger.debug("OpenVPN is already running, trying to stop it...") + cmdline = process.cmdline + + manag_flag = "--management" + if isinstance(cmdline, list) and manag_flag in cmdline: + try: + index = cmdline.index(manag_flag) + host = cmdline[index + 1] + port = cmdline[index + 2] + logger.debug("Trying to connect to %s:%s" + % (host, port)) + self._connect_to_management(host, port) + self._send_command("signal SIGTERM") + self._tn.close() + self._tn = None + #self._disconnect_management() + except Exception as e: + logger.warning("Problem trying to terminate OpenVPN: %r" + % (e,)) + + process = self._get_openvpn_process() + if process is None: + logger.warning("Unabled to terminate OpenVPN") + return True + else: + return False + return True + + +class VPNProcess(protocol.ProcessProtocol, VPNManager): + """ + A ProcessProtocol class that can be used to spawn a process that will + launch openvpn and connect to its management interface to control it + programmatically. + """ + + def __init__(self, eipconfig, providerconfig, socket_host, socket_port, + qtsigs): + """ + :param eipconfig: eip configuration object + :type eipconfig: EIPConfig + + :param providerconfig: provider specific configuration + :type providerconfig: ProviderConfig + + :param socket_host: either socket path (unix) or socket IP + :type socket_host: str + + :param socket_port: either string "unix" if it's a unix + socket, or port otherwise + :type socket_port: str + + :param qtsigs: a QObject containing the Qt signals used to notify the + UI. + :type qtsigs: QObject + """ + VPNManager.__init__(self, qtsigs=qtsigs) + leap_assert_type(eipconfig, EIPConfig) + leap_assert_type(providerconfig, ProviderConfig) + leap_assert_type(qtsigs, QtCore.QObject) + + #leap_assert(not self.isRunning(), "Starting process more than once!") + + self._eipconfig = eipconfig + self._providerconfig = providerconfig + self._socket_host = socket_host + self._socket_port = socket_port + + self._launcher = get_platform_launcher() + + self._last_state = None + self._last_status = None + self._alive = False + + # processProtocol methods + + def connectionMade(self): + """ + Called when the connection is made. + + .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa + """ + self._alive = True + self.aborted = False + self.try_to_connect_to_management() + + def outReceived(self, data): + """ + Called when new data is available on stdout. + + :param data: the data read on stdout + + .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa + """ + # truncate the newline + # should send this to the logging window + vpnlog.info(data[:-1]) + + def processExited(self, reason): + """ + Called when the child process exits. + + .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa + """ + exit_code = reason.value.exitCode + if isinstance(exit_code, int): + logger.debug("processExited, status %d" % (exit_code,)) + self.qtsigs.process_finished.emit(exit_code) + self._alive = False + + def processEnded(self, reason): + """ + Called when the child process exits and all file descriptors associated + with it have been closed. + + .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa + """ + exit_code = reason.value.exitCode + if isinstance(exit_code, int): + logger.debug("processEnded, status %d" % (exit_code,)) + + # polling + + def pollStatus(self): + """ + Polls connection status. + """ + if self._alive: + self.get_status() + + def pollState(self): + """ + Polls connection state. + """ + if self._alive: + self.get_state() + + # launcher + + def getCommand(self): + """ + Gets the vpn command from the aproppriate launcher. + """ + cmd = self._launcher.get_vpn_command( + eipconfig=self._eipconfig, + providerconfig=self._providerconfig, + socket_host=self._socket_host, + socket_port=self._socket_port) + return map(str, cmd) + + # shutdown + + def killProcess(self): + """ + Sends the KILL signal to the running process. + """ + try: + self.transport.signalProcess('KILL') + except internet_error.ProcessExitedAlready: + logger.debug('Process Exited Already') |