summaryrefslogtreecommitdiff
path: root/src/leap/services/eip
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/services/eip')
-rw-r--r--src/leap/services/eip/__init__.py0
-rw-r--r--src/leap/services/eip/eipbootstrapper.py178
-rw-r--r--src/leap/services/eip/eipconfig.py260
-rw-r--r--src/leap/services/eip/eipspec.py65
-rw-r--r--src/leap/services/eip/providerbootstrapper.py311
-rw-r--r--src/leap/services/eip/tests/__init__.py0
-rw-r--r--src/leap/services/eip/tests/test_eipbootstrapper.py347
-rw-r--r--src/leap/services/eip/tests/test_eipconfig.py313
-rw-r--r--src/leap/services/eip/tests/test_providerbootstrapper.py504
-rw-r--r--src/leap/services/eip/tests/test_vpngatewayselector.py131
-rw-r--r--src/leap/services/eip/tests/wrongcert.pem33
-rw-r--r--src/leap/services/eip/udstelnet.py60
-rw-r--r--src/leap/services/eip/vpnlaunchers.py774
-rw-r--r--src/leap/services/eip/vpnprocess.py718
14 files changed, 3694 insertions, 0 deletions
diff --git a/src/leap/services/eip/__init__.py b/src/leap/services/eip/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/services/eip/__init__.py
diff --git a/src/leap/services/eip/eipbootstrapper.py b/src/leap/services/eip/eipbootstrapper.py
new file mode 100644
index 00000000..b2af0aea
--- /dev/null
+++ b/src/leap/services/eip/eipbootstrapper.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+# eipbootstrapper.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+EIP bootstrapping
+"""
+
+import logging
+import os
+
+from PySide import QtCore
+
+from leap.common.check import leap_assert, leap_assert_type
+from leap.common import certs
+from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p
+from leap.config.providerconfig import ProviderConfig
+from leap.crypto.srpauth import SRPAuth
+from leap.services.eip.eipconfig import EIPConfig
+from leap.util.request_helpers import get_content
+from leap.services.abstractbootstrapper import AbstractBootstrapper
+
+logger = logging.getLogger(__name__)
+
+
+class EIPBootstrapper(AbstractBootstrapper):
+ """
+ Sets up EIP for a provider a series of checks and emits signals
+ after they are passed.
+ If a check fails, the subsequent checks are not executed
+ """
+
+ # All dicts returned are of the form
+ # {"passed": bool, "error": str}
+ download_config = QtCore.Signal(dict)
+ download_client_certificate = QtCore.Signal(dict)
+
+ def __init__(self):
+ AbstractBootstrapper.__init__(self)
+
+ self._provider_config = None
+ self._eip_config = None
+ self._download_if_needed = False
+
+ def _download_config(self, *args):
+ """
+ Downloads the EIP config for the given provider
+ """
+
+ leap_assert(self._provider_config,
+ "We need a provider configuration!")
+
+ logger.debug("Downloading EIP config for %s" %
+ (self._provider_config.get_domain(),))
+
+ self._eip_config = EIPConfig()
+
+ headers = {}
+ mtime = get_mtime(os.path.join(self._eip_config
+ .get_path_prefix(),
+ "leap",
+ "providers",
+ self._provider_config.get_domain(),
+ "eip-service.json"))
+
+ if self._download_if_needed and mtime:
+ headers['if-modified-since'] = mtime
+
+ # there is some confusion with this uri,
+ # it's in 1/config/eip, config/eip and config/1/eip...
+ config_uri = "%s/%s/config/eip-service.json" % (
+ self._provider_config.get_api_uri(),
+ self._provider_config.get_api_version())
+ logger.debug('Downloading eip config from: %s' % config_uri)
+
+ res = self._session.get(config_uri,
+ verify=self._provider_config
+ .get_ca_cert_path(),
+ headers=headers)
+ res.raise_for_status()
+
+ # Not modified
+ if res.status_code == 304:
+ logger.debug("EIP definition has not been modified")
+ else:
+ eip_definition, mtime = get_content(res)
+
+ self._eip_config.load(data=eip_definition, mtime=mtime)
+ self._eip_config.save(["leap",
+ "providers",
+ self._provider_config.get_domain(),
+ "eip-service.json"])
+
+ def _download_client_certificates(self, *args):
+ """
+ Downloads the EIP client certificate for the given provider
+ """
+ leap_assert(self._provider_config, "We need a provider configuration!")
+ leap_assert(self._eip_config, "We need an eip configuration!")
+
+ logger.debug("Downloading EIP client certificate for %s" %
+ (self._provider_config.get_domain(),))
+
+ client_cert_path = self._eip_config.\
+ get_client_cert_path(self._provider_config,
+ about_to_download=True)
+
+ # For re-download if something is wrong with the cert
+ self._download_if_needed = self._download_if_needed and \
+ not certs.should_redownload(client_cert_path)
+
+ if self._download_if_needed and \
+ os.path.exists(client_cert_path):
+ check_and_fix_urw_only(client_cert_path)
+ return
+
+ srp_auth = SRPAuth(self._provider_config)
+ session_id = srp_auth.get_session_id()
+ cookies = None
+ if session_id:
+ cookies = {"_session_id": session_id}
+ cert_uri = "%s/%s/cert" % (
+ self._provider_config.get_api_uri(),
+ self._provider_config.get_api_version())
+ logger.debug('getting cert from uri: %s' % cert_uri)
+ res = self._session.get(cert_uri,
+ verify=self._provider_config
+ .get_ca_cert_path(),
+ cookies=cookies)
+ res.raise_for_status()
+ client_cert = res.content
+
+ if not certs.is_valid_pemfile(client_cert):
+ raise Exception(self.tr("The downloaded certificate is not a "
+ "valid PEM file"))
+
+ mkdir_p(os.path.dirname(client_cert_path))
+
+ with open(client_cert_path, "w") as f:
+ f.write(client_cert)
+
+ check_and_fix_urw_only(client_cert_path)
+
+ def run_eip_setup_checks(self,
+ provider_config,
+ download_if_needed=False):
+ """
+ Starts the checks needed for a new eip setup
+
+ :param provider_config: Provider configuration
+ :type provider_config: ProviderConfig
+ """
+ leap_assert(provider_config, "We need a provider config!")
+ leap_assert_type(provider_config, ProviderConfig)
+
+ self._provider_config = provider_config
+ self._download_if_needed = download_if_needed
+
+ cb_chain = [
+ (self._download_config, self.download_config),
+ (self._download_client_certificates,
+ self.download_client_certificate)
+ ]
+
+ return self.addCallbackChain(cb_chain)
diff --git a/src/leap/services/eip/eipconfig.py b/src/leap/services/eip/eipconfig.py
new file mode 100644
index 00000000..9e3a9b29
--- /dev/null
+++ b/src/leap/services/eip/eipconfig.py
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+# eipconfig.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+Provider configuration
+"""
+import logging
+import os
+import re
+import time
+
+import ipaddr
+
+from leap.common.check import leap_assert, leap_assert_type
+from leap.common.config.baseconfig import BaseConfig
+from leap.config.providerconfig import ProviderConfig
+from leap.services.eip.eipspec import eipservice_config_spec
+
+logger = logging.getLogger(__name__)
+
+
+class VPNGatewaySelector(object):
+ """
+ VPN Gateway selector.
+ """
+ # http://www.timeanddate.com/time/map/
+ equivalent_timezones = {13: -11, 14: -10}
+
+ def __init__(self, eipconfig, tz_offset=None):
+ '''
+ Constructor for VPNGatewaySelector.
+
+ :param eipconfig: a valid EIP Configuration.
+ :type eipconfig: EIPConfig
+ :param tz_offset: use this offset as a local distance to GMT.
+ :type tz_offset: int
+ '''
+ leap_assert_type(eipconfig, EIPConfig)
+
+ self._local_offset = tz_offset
+ if tz_offset is None:
+ tz_offset = self._get_local_offset()
+
+ if tz_offset in self.equivalent_timezones:
+ tz_offset = self.equivalent_timezones[tz_offset]
+
+ self._local_offset = tz_offset
+
+ self._eipconfig = eipconfig
+
+ def get_gateways(self):
+ """
+ Returns the 4 best gateways, sorted by timezone proximity.
+
+ :rtype: list of IPv4Address or IPv6Address object.
+ """
+ gateways_timezones = []
+ locations = self._eipconfig.get_locations()
+ gateways = self._eipconfig.get_gateways()
+
+ for idx, gateway in enumerate(gateways):
+ gateway_location = gateway.get('location')
+ gateway_distance = 99 # if hasn't location -> should go last
+
+ if gateway_location is not None:
+ gw_offset = int(locations[gateway['location']]['timezone'])
+ if gw_offset in self.equivalent_timezones:
+ gw_offset = self.equivalent_timezones[gw_offset]
+
+ gateway_distance = self._get_timezone_distance(gw_offset)
+
+ ip = self._eipconfig.get_gateway_ip(idx)
+ gateways_timezones.append((ip, gateway_distance))
+
+ gateways_timezones = sorted(gateways_timezones,
+ key=lambda gw: gw[1])[:4]
+
+ gateways = [ip for ip, dist in gateways_timezones]
+ return gateways
+
+ def _get_timezone_distance(self, offset):
+ '''
+ Returns the distance between the local timezone and
+ the one with offset 'offset'.
+
+ :param offset: the distance of a timezone to GMT.
+ :type offset: int
+ :returns: distance between local offset and param offset.
+ :rtype: int
+ '''
+ timezones = range(-11, 13)
+ tz1 = offset
+ tz2 = self._local_offset
+ distance = abs(timezones.index(tz1) - timezones.index(tz2))
+ if distance > 12:
+ if tz1 < 0:
+ distance = timezones.index(tz1) + timezones[::-1].index(tz2)
+ else:
+ distance = timezones[::-1].index(tz1) + timezones.index(tz2)
+
+ return distance
+
+ def _get_local_offset(self):
+ '''
+ Returns the distance between GMT and the local timezone.
+
+ :rtype: int
+ '''
+ local_offset = time.timezone
+ if time.daylight:
+ local_offset = time.altzone
+
+ return local_offset / 3600
+
+
+class EIPConfig(BaseConfig):
+ """
+ Provider configuration abstraction class
+ """
+ OPENVPN_ALLOWED_KEYS = ("auth", "cipher", "tls-cipher")
+ OPENVPN_CIPHERS_REGEX = re.compile("[A-Z0-9\-]+")
+
+ def __init__(self):
+ BaseConfig.__init__(self)
+
+ def _get_spec(self):
+ """
+ Returns the spec object for the specific configuration
+ """
+ return eipservice_config_spec
+
+ def get_clusters(self):
+ # TODO: create an abstraction for clusters
+ return self._safe_get_value("clusters")
+
+ def get_gateways(self):
+ # TODO: create an abstraction for gateways
+ return self._safe_get_value("gateways")
+
+ def get_locations(self):
+ '''
+ Returns a list of locations
+
+ :rtype: dict
+ '''
+ return self._safe_get_value("locations")
+
+ def get_openvpn_configuration(self):
+ """
+ Returns a dictionary containing the openvpn configuration
+ parameters.
+
+ These are sanitized with alphanumeric whitelist.
+
+ :returns: openvpn configuration dict
+ :rtype: C{dict}
+ """
+ ovpncfg = self._safe_get_value("openvpn_configuration")
+ config = {}
+ for key, value in ovpncfg.items():
+ if key in self.OPENVPN_ALLOWED_KEYS and value is not None:
+ sanitized_val = self.OPENVPN_CIPHERS_REGEX.findall(value)
+ if len(sanitized_val) != 0:
+ _val = sanitized_val[0]
+ config[str(key)] = str(_val)
+ return config
+
+ def get_serial(self):
+ return self._safe_get_value("serial")
+
+ def get_version(self):
+ return self._safe_get_value("version")
+
+ def get_gateway_ip(self, index=0):
+ """
+ Returns the ip of the gateway.
+
+ :rtype: An IPv4Address or IPv6Address object.
+ """
+ gateways = self.get_gateways()
+ leap_assert(len(gateways) > 0, "We don't have any gateway!")
+ if index > len(gateways):
+ index = 0
+ logger.warning("Provided an unknown gateway index %s, " +
+ "defaulting to 0")
+ ip_addr_str = gateways[index]["ip_address"]
+
+ try:
+ ipaddr.IPAddress(ip_addr_str)
+ return ip_addr_str
+ except ValueError:
+ logger.error("Invalid ip address in config: %s" % (ip_addr_str,))
+ return None
+
+ def get_client_cert_path(self,
+ providerconfig=None,
+ about_to_download=False):
+ """
+ Returns the path to the certificate used by openvpn
+ """
+
+ leap_assert(providerconfig, "We need a provider")
+ leap_assert_type(providerconfig, ProviderConfig)
+
+ cert_path = os.path.join(self.get_path_prefix(),
+ "leap",
+ "providers",
+ providerconfig.get_domain(),
+ "keys",
+ "client",
+ "openvpn.pem")
+
+ if not about_to_download:
+ leap_assert(os.path.exists(cert_path),
+ "You need to download the certificate first")
+ logger.debug("Using OpenVPN cert %s" % (cert_path,))
+
+ return cert_path
+
+
+if __name__ == "__main__":
+ logger = logging.getLogger(name='leap')
+ logger.setLevel(logging.DEBUG)
+ console = logging.StreamHandler()
+ console.setLevel(logging.DEBUG)
+ formatter = logging.Formatter(
+ '%(asctime)s '
+ '- %(name)s - %(levelname)s - %(message)s')
+ console.setFormatter(formatter)
+ logger.addHandler(console)
+
+ eipconfig = EIPConfig()
+
+ try:
+ eipconfig.get_clusters()
+ except Exception as e:
+ assert isinstance(e, AssertionError), "Expected an assert"
+ print "Safe value getting is working"
+
+ if eipconfig.load("leap/providers/bitmask.net/eip-service.json"):
+ print eipconfig.get_clusters()
+ print eipconfig.get_gateways()
+ print eipconfig.get_locations()
+ print eipconfig.get_openvpn_configuration()
+ print eipconfig.get_serial()
+ print eipconfig.get_version()
diff --git a/src/leap/services/eip/eipspec.py b/src/leap/services/eip/eipspec.py
new file mode 100644
index 00000000..94ba674f
--- /dev/null
+++ b/src/leap/services/eip/eipspec.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# eipspec.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+eipservice_config_spec = {
+ 'description': 'sample eip service config',
+ 'type': 'object',
+ 'properties': {
+ 'serial': {
+ 'type': int,
+ 'default': 1,
+ 'required': ["True"]
+ },
+ 'version': {
+ 'type': int,
+ 'default': 1,
+ 'required': ["True"]
+ },
+ 'clusters': {
+ 'type': list,
+ 'default': [
+ {"label": {
+ "en": "Location Unknown"},
+ "name": "location_unknown"}]
+ },
+ 'gateways': {
+ 'type': list,
+ 'default': [
+ {"capabilities": {
+ "adblock": True,
+ "filter_dns": True,
+ "ports": ["80", "53", "443", "1194"],
+ "protocols": ["udp", "tcp"],
+ "transport": ["openvpn"],
+ "user_ips": False},
+ "cluster": "location_unknown",
+ "host": "location.example.org",
+ "ip_address": "127.0.0.1"}]
+ },
+ 'locations': {
+ 'type': dict,
+ 'default': {}
+ },
+ 'openvpn_configuration': {
+ 'type': dict,
+ 'default': {
+ "auth": None,
+ "cipher": None,
+ "tls-cipher": None}
+ }
+ }
+}
diff --git a/src/leap/services/eip/providerbootstrapper.py b/src/leap/services/eip/providerbootstrapper.py
new file mode 100644
index 00000000..754d0643
--- /dev/null
+++ b/src/leap/services/eip/providerbootstrapper.py
@@ -0,0 +1,311 @@
+# -*- coding: utf-8 -*-
+# providerbootstrapper.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+Provider bootstrapping
+"""
+import logging
+import socket
+import os
+
+import requests
+
+from PySide import QtCore
+
+from leap.common.certs import get_digest
+from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p
+from leap.common.check import leap_assert, leap_assert_type
+from leap.config.providerconfig import ProviderConfig
+from leap.util.request_helpers import get_content
+from leap.services.abstractbootstrapper import AbstractBootstrapper
+from leap.provider.supportedapis import SupportedAPIs
+
+
+logger = logging.getLogger(__name__)
+
+
+class UnsupportedProviderAPI(Exception):
+ """
+ Raised when attempting to use a provider with an incompatible API.
+ """
+ pass
+
+
+class ProviderBootstrapper(AbstractBootstrapper):
+ """
+ Given a provider URL performs a series of checks and emits signals
+ after they are passed.
+ If a check fails, the subsequent checks are not executed
+ """
+
+ # All dicts returned are of the form
+ # {"passed": bool, "error": str}
+ name_resolution = QtCore.Signal(dict)
+ https_connection = QtCore.Signal(dict)
+ download_provider_info = QtCore.Signal(dict)
+
+ download_ca_cert = QtCore.Signal(dict)
+ check_ca_fingerprint = QtCore.Signal(dict)
+ check_api_certificate = QtCore.Signal(dict)
+
+ def __init__(self, bypass_checks=False):
+ """
+ Constructor for provider bootstrapper object
+
+ :param bypass_checks: Set to true if the app should bypass
+ first round of checks for CA certificates at bootstrap
+ :type bypass_checks: bool
+ """
+ AbstractBootstrapper.__init__(self, bypass_checks)
+
+ self._domain = None
+ self._provider_config = None
+ self._download_if_needed = False
+
+ def _check_name_resolution(self):
+ """
+ Checks that the name resolution for the provider name works
+ """
+ leap_assert(self._domain, "Cannot check DNS without a domain")
+
+ logger.debug("Checking name resolution for %s" % (self._domain))
+
+ # We don't skip this check, since it's basic for the whole
+ # system to work
+ socket.gethostbyname(self._domain)
+
+ def _check_https(self, *args):
+ """
+ Checks that https is working and that the provided certificate
+ checks out
+ """
+
+ leap_assert(self._domain, "Cannot check HTTPS without a domain")
+
+ logger.debug("Checking https for %s" % (self._domain))
+
+ # We don't skip this check, since it's basic for the whole
+ # system to work
+
+ try:
+ res = self._session.get("https://%s" % (self._domain,),
+ verify=not self._bypass_checks)
+ res.raise_for_status()
+ except requests.exceptions.SSLError:
+ self._err_msg = self.tr("Provider certificate could "
+ "not be verified")
+ raise
+ except Exception:
+ self._err_msg = self.tr("Provider does not support HTTPS")
+ raise
+
+ def _download_provider_info(self, *args):
+ """
+ Downloads the provider.json defition
+ """
+ leap_assert(self._domain,
+ "Cannot download provider info without a domain")
+
+ logger.debug("Downloading provider info for %s" % (self._domain))
+
+ headers = {}
+ mtime = get_mtime(os.path.join(ProviderConfig()
+ .get_path_prefix(),
+ "leap",
+ "providers",
+ self._domain,
+ "provider.json"))
+ if self._download_if_needed and mtime:
+ headers['if-modified-since'] = mtime
+
+ res = self._session.get("https://%s/%s" % (self._domain,
+ "provider.json"),
+ headers=headers,
+ verify=not self._bypass_checks)
+ res.raise_for_status()
+
+ # Not modified
+ if res.status_code == 304:
+ logger.debug("Provider definition has not been modified")
+ else:
+ provider_definition, mtime = get_content(res)
+
+ provider_config = ProviderConfig()
+ provider_config.load(data=provider_definition, mtime=mtime)
+ provider_config.save(["leap",
+ "providers",
+ self._domain,
+ "provider.json"])
+
+ api_version = provider_config.get_api_version()
+ if SupportedAPIs.supports(api_version):
+ logger.debug("Provider definition has been modified")
+ else:
+ api_supported = ', '.join(SupportedAPIs.SUPPORTED_APIS)
+ error = ('Unsupported provider API version. '
+ 'Supported versions are: {}. '
+ 'Found: {}.').format(api_supported, api_version)
+
+ logger.error(error)
+ raise UnsupportedProviderAPI(error)
+
+ def run_provider_select_checks(self, domain, download_if_needed=False):
+ """
+ Populates the check queue.
+
+ :param domain: domain to check
+ :type domain: str
+
+ :param download_if_needed: if True, makes the checks do not
+ overwrite already downloaded data
+ :type download_if_needed: bool
+ """
+ leap_assert(domain and len(domain) > 0, "We need a domain!")
+
+ self._domain = domain
+ self._download_if_needed = download_if_needed
+
+ cb_chain = [
+ (self._check_name_resolution, self.name_resolution),
+ (self._check_https, self.https_connection),
+ (self._download_provider_info, self.download_provider_info)
+ ]
+
+ return self.addCallbackChain(cb_chain)
+
+ def _should_proceed_cert(self):
+ """
+ Returns False if the certificate already exists for the given
+ provider. True otherwise
+
+ :rtype: bool
+ """
+ leap_assert(self._provider_config, "We need a provider config!")
+
+ if not self._download_if_needed:
+ return True
+
+ return not os.path.exists(self._provider_config
+ .get_ca_cert_path(about_to_download=True))
+
+ def _download_ca_cert(self, *args):
+ """
+ Downloads the CA cert that is going to be used for the api URL
+ """
+
+ leap_assert(self._provider_config, "Cannot download the ca cert "
+ "without a provider config!")
+
+ logger.debug("Downloading ca cert for %s at %s" %
+ (self._domain, self._provider_config.get_ca_cert_uri()))
+
+ if not self._should_proceed_cert():
+ check_and_fix_urw_only(
+ self._provider_config
+ .get_ca_cert_path(about_to_download=True))
+ return
+
+ res = self._session.get(self._provider_config.get_ca_cert_uri(),
+ verify=not self._bypass_checks)
+ res.raise_for_status()
+
+ cert_path = self._provider_config.get_ca_cert_path(
+ about_to_download=True)
+ cert_dir = os.path.dirname(cert_path)
+ mkdir_p(cert_dir)
+ with open(cert_path, "w") as f:
+ f.write(res.content)
+
+ check_and_fix_urw_only(cert_path)
+
+ def _check_ca_fingerprint(self, *args):
+ """
+ Checks the CA cert fingerprint against the one provided in the
+ json definition
+ """
+ leap_assert(self._provider_config, "Cannot check the ca cert "
+ "without a provider config!")
+
+ logger.debug("Checking ca fingerprint for %s and cert %s" %
+ (self._domain,
+ self._provider_config.get_ca_cert_path()))
+
+ if not self._should_proceed_cert():
+ return
+
+ parts = self._provider_config.get_ca_cert_fingerprint().split(":")
+ leap_assert(len(parts) == 2, "Wrong fingerprint format")
+
+ method = parts[0].strip()
+ fingerprint = parts[1].strip()
+ cert_data = None
+ with open(self._provider_config.get_ca_cert_path()) as f:
+ cert_data = f.read()
+
+ leap_assert(len(cert_data) > 0, "Could not read certificate data")
+ digest = get_digest(cert_data, method)
+ leap_assert(digest == fingerprint,
+ "Downloaded certificate has a different fingerprint!")
+
+ def _check_api_certificate(self, *args):
+ """
+ Tries to make an API call with the downloaded cert and checks
+ if it validates against it
+ """
+ leap_assert(self._provider_config, "Cannot check the ca cert "
+ "without a provider config!")
+
+ logger.debug("Checking api certificate for %s and cert %s" %
+ (self._provider_config.get_api_uri(),
+ self._provider_config.get_ca_cert_path()))
+
+ if not self._should_proceed_cert():
+ return
+
+ test_uri = "%s/%s/cert" % (self._provider_config.get_api_uri(),
+ self._provider_config.get_api_version())
+ res = self._session.get(test_uri,
+ verify=self._provider_config
+ .get_ca_cert_path())
+ res.raise_for_status()
+
+ def run_provider_setup_checks(self,
+ provider_config,
+ download_if_needed=False):
+ """
+ Starts the checks needed for a new provider setup.
+
+ :param provider_config: Provider configuration
+ :type provider_config: ProviderConfig
+
+ :param download_if_needed: if True, makes the checks do not
+ overwrite already downloaded data.
+ :type download_if_needed: bool
+ """
+ leap_assert(provider_config, "We need a provider config!")
+ leap_assert_type(provider_config, ProviderConfig)
+
+ self._provider_config = provider_config
+ self._download_if_needed = download_if_needed
+
+ cb_chain = [
+ (self._download_ca_cert, self.download_ca_cert),
+ (self._check_ca_fingerprint, self.check_ca_fingerprint),
+ (self._check_api_certificate, self.check_api_certificate)
+ ]
+
+ return self.addCallbackChain(cb_chain)
diff --git a/src/leap/services/eip/tests/__init__.py b/src/leap/services/eip/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/services/eip/tests/__init__.py
diff --git a/src/leap/services/eip/tests/test_eipbootstrapper.py b/src/leap/services/eip/tests/test_eipbootstrapper.py
new file mode 100644
index 00000000..f2331eca
--- /dev/null
+++ b/src/leap/services/eip/tests/test_eipbootstrapper.py
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+# test_eipbootstrapper.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+"""
+Tests for the EIP Boostrapper checks
+
+These will be whitebox tests since we want to make sure the private
+implementation is checking what we expect.
+"""
+
+import os
+import mock
+import tempfile
+import time
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from nose.twistedtools import deferred, reactor
+from twisted.internet import threads
+from requests.models import Response
+
+from leap.common.testing.basetest import BaseLeapTest
+from leap.services.eip.eipbootstrapper import EIPBootstrapper
+from leap.services.eip.eipconfig import EIPConfig
+from leap.config.providerconfig import ProviderConfig
+from leap.crypto.tests import fake_provider
+from leap.common.files import mkdir_p
+from leap.crypto.srpauth import SRPAuth
+
+
+class EIPBootstrapperActiveTest(BaseLeapTest):
+ @classmethod
+ def setUpClass(cls):
+ BaseLeapTest.setUpClass()
+ factory = fake_provider.get_provider_factory()
+ http = reactor.listenTCP(0, factory)
+ https = reactor.listenSSL(
+ 0, factory,
+ fake_provider.OpenSSLServerContextFactory())
+ get_port = lambda p: p.getHost().port
+ cls.http_port = get_port(http)
+ cls.https_port = get_port(https)
+
+ def setUp(self):
+ self.eb = EIPBootstrapper()
+ self.old_pp = EIPConfig.get_path_prefix
+ self.old_save = EIPConfig.save
+ self.old_load = EIPConfig.load
+ self.old_si = SRPAuth.get_session_id
+
+ def tearDown(self):
+ EIPConfig.get_path_prefix = self.old_pp
+ EIPConfig.save = self.old_save
+ EIPConfig.load = self.old_load
+ SRPAuth.get_session_id = self.old_si
+
+ def _download_config_test_template(self, ifneeded, new):
+ """
+ All download config tests have the same structure, so this is
+ a parametrized test for that.
+
+ :param ifneeded: sets _download_if_needed
+ :type ifneeded: bool
+ :param new: if True uses time.time() as mtime for the mocked
+ eip-service file, otherwise it uses 100 (a really
+ old mtime)
+ :type new: float or int (will be coersed)
+ """
+ pc = ProviderConfig()
+ pc.get_domain = mock.MagicMock(
+ return_value="localhost:%s" % (self.https_port))
+ self.eb._provider_config = pc
+
+ pc.get_api_uri = mock.MagicMock(
+ return_value="https://%s" % (pc.get_domain()))
+ pc.get_api_version = mock.MagicMock(return_value="1")
+
+ # This is to ignore https checking, since it's not the point
+ # of this test
+ pc.get_ca_cert_path = mock.MagicMock(return_value=False)
+
+ path_prefix = tempfile.mkdtemp()
+ EIPConfig.get_path_prefix = mock.MagicMock(return_value=path_prefix)
+ EIPConfig.save = mock.MagicMock()
+ EIPConfig.load = mock.MagicMock()
+
+ self.eb._download_if_needed = ifneeded
+
+ provider_dir = os.path.join(EIPConfig.get_path_prefix(),
+ "leap",
+ "providers",
+ pc.get_domain())
+ mkdir_p(provider_dir)
+ eip_config_path = os.path.join(provider_dir,
+ "eip-service.json")
+
+ with open(eip_config_path, "w") as ec:
+ ec.write("A")
+
+ # set mtime to something really new
+ if new:
+ os.utime(eip_config_path, (-1, time.time()))
+ else:
+ os.utime(eip_config_path, (-1, 100))
+
+ @deferred()
+ def test_download_config_not_modified(self):
+ self._download_config_test_template(True, True)
+
+ d = threads.deferToThread(self.eb._download_config)
+
+ def check(*args):
+ self.assertFalse(self.eb._eip_config.save.called)
+ d.addCallback(check)
+ return d
+
+ @deferred()
+ def test_download_config_modified(self):
+ self._download_config_test_template(True, False)
+
+ d = threads.deferToThread(self.eb._download_config)
+
+ def check(*args):
+ self.assertTrue(self.eb._eip_config.save.called)
+ d.addCallback(check)
+ return d
+
+ @deferred()
+ def test_download_config_ignores_mtime(self):
+ self._download_config_test_template(False, True)
+
+ d = threads.deferToThread(self.eb._download_config)
+
+ def check(*args):
+ self.eb._eip_config.save.assert_called_once_with(
+ ["leap",
+ "providers",
+ self.eb._provider_config.get_domain(),
+ "eip-service.json"])
+ d.addCallback(check)
+ return d
+
+ def _download_certificate_test_template(self, ifneeded, createcert):
+ """
+ All download client certificate tests have the same structure,
+ so this is a parametrized test for that.
+
+ :param ifneeded: sets _download_if_needed
+ :type ifneeded: bool
+ :param createcert: if True it creates a dummy file to play the
+ part of a downloaded certificate
+ :type createcert: bool
+
+ :returns: the temp eip cert path and the dummy cert contents
+ :rtype: tuple of str, str
+ """
+ pc = ProviderConfig()
+ ec = EIPConfig()
+ self.eb._provider_config = pc
+ self.eb._eip_config = ec
+
+ pc.get_domain = mock.MagicMock(
+ return_value="localhost:%s" % (self.https_port))
+ pc.get_api_uri = mock.MagicMock(
+ return_value="https://%s" % (pc.get_domain()))
+ pc.get_api_version = mock.MagicMock(return_value="1")
+ pc.get_ca_cert_path = mock.MagicMock(return_value=False)
+
+ path_prefix = tempfile.mkdtemp()
+ EIPConfig.get_path_prefix = mock.MagicMock(return_value=path_prefix)
+ EIPConfig.save = mock.MagicMock()
+ EIPConfig.load = mock.MagicMock()
+
+ self.eb._download_if_needed = ifneeded
+
+ provider_dir = os.path.join(EIPConfig.get_path_prefix(),
+ "leap",
+ "providers",
+ "somedomain")
+ mkdir_p(provider_dir)
+ eip_cert_path = os.path.join(provider_dir,
+ "cert")
+
+ ec.get_client_cert_path = mock.MagicMock(
+ return_value=eip_cert_path)
+
+ cert_content = "A"
+ if createcert:
+ with open(eip_cert_path, "w") as ec:
+ ec.write(cert_content)
+
+ return eip_cert_path, cert_content
+
+ def test_download_client_certificate_not_modified(self):
+ cert_path, old_cert_content = self._download_certificate_test_template(
+ True, True)
+
+ with mock.patch('leap.common.certs.should_redownload',
+ new_callable=mock.MagicMock,
+ return_value=False):
+ self.eb._download_client_certificates()
+ with open(cert_path, "r") as c:
+ self.assertEqual(c.read(), old_cert_content)
+
+ @deferred()
+ def test_download_client_certificate_old_cert(self):
+ cert_path, old_cert_content = self._download_certificate_test_template(
+ True, True)
+
+ def wrapper(*args):
+ with mock.patch('leap.common.certs.should_redownload',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ with mock.patch('leap.common.certs.is_valid_pemfile',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ self.eb._download_client_certificates()
+
+ def check(*args):
+ with open(cert_path, "r") as c:
+ self.assertNotEqual(c.read(), old_cert_content)
+ d = threads.deferToThread(wrapper)
+ d.addCallback(check)
+
+ return d
+
+ @deferred()
+ def test_download_client_certificate_no_cert(self):
+ cert_path, _ = self._download_certificate_test_template(
+ True, False)
+
+ def wrapper(*args):
+ with mock.patch('leap.common.certs.should_redownload',
+ new_callable=mock.MagicMock,
+ return_value=False):
+ with mock.patch('leap.common.certs.is_valid_pemfile',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ self.eb._download_client_certificates()
+
+ def check(*args):
+ self.assertTrue(os.path.exists(cert_path))
+ d = threads.deferToThread(wrapper)
+ d.addCallback(check)
+
+ return d
+
+ @deferred()
+ def test_download_client_certificate_force_not_valid(self):
+ cert_path, old_cert_content = self._download_certificate_test_template(
+ True, True)
+
+ def wrapper(*args):
+ with mock.patch('leap.common.certs.should_redownload',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ with mock.patch('leap.common.certs.is_valid_pemfile',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ self.eb._download_client_certificates()
+
+ def check(*args):
+ with open(cert_path, "r") as c:
+ self.assertNotEqual(c.read(), old_cert_content)
+ d = threads.deferToThread(wrapper)
+ d.addCallback(check)
+
+ return d
+
+ @deferred()
+ def test_download_client_certificate_invalid_download(self):
+ cert_path, _ = self._download_certificate_test_template(
+ False, False)
+
+ def wrapper(*args):
+ with mock.patch('leap.common.certs.should_redownload',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ with mock.patch('leap.common.certs.is_valid_pemfile',
+ new_callable=mock.MagicMock,
+ return_value=False):
+ with self.assertRaises(Exception):
+ self.eb._download_client_certificates()
+ d = threads.deferToThread(wrapper)
+
+ return d
+
+ @deferred()
+ def test_download_client_certificate_uses_session_id(self):
+ _, _ = self._download_certificate_test_template(
+ False, False)
+
+ SRPAuth.get_session_id = mock.MagicMock(return_value="1")
+
+ def check_cookie(*args, **kwargs):
+ cookies = kwargs.get("cookies", None)
+ self.assertEqual(cookies, {'_session_id': '1'})
+ return Response()
+
+ def wrapper(*args):
+ with mock.patch('leap.common.certs.should_redownload',
+ new_callable=mock.MagicMock,
+ return_value=False):
+ with mock.patch('leap.common.certs.is_valid_pemfile',
+ new_callable=mock.MagicMock,
+ return_value=True):
+ with mock.patch('requests.sessions.Session.get',
+ new_callable=mock.MagicMock,
+ side_effect=check_cookie):
+ with mock.patch('requests.models.Response.content',
+ new_callable=mock.PropertyMock,
+ return_value="A"):
+ self.eb._download_client_certificates()
+
+ d = threads.deferToThread(wrapper)
+
+ return d
+
+ @deferred()
+ def test_run_eip_setup_checks(self):
+ self.eb._download_config = mock.MagicMock()
+ self.eb._download_client_certificates = mock.MagicMock()
+
+ d = self.eb.run_eip_setup_checks(ProviderConfig())
+
+ def check(*args):
+ self.eb._download_config.assert_called_once_with()
+ self.eb._download_client_certificates.assert_called_once_with(None)
+ d.addCallback(check)
+ return d
diff --git a/src/leap/services/eip/tests/test_eipconfig.py b/src/leap/services/eip/tests/test_eipconfig.py
new file mode 100644
index 00000000..8b746b78
--- /dev/null
+++ b/src/leap/services/eip/tests/test_eipconfig.py
@@ -0,0 +1,313 @@
+# -*- coding: utf-8 -*-
+# test_eipconfig.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for eipconfig
+"""
+import copy
+import json
+import os
+import unittest
+
+from leap.common.testing.basetest import BaseLeapTest
+from leap.services.eip.eipconfig import EIPConfig
+from leap.config.providerconfig import ProviderConfig
+
+from mock import Mock
+
+
+sample_config = {
+ "gateways": [
+ {
+ "capabilities": {
+ "adblock": False,
+ "filter_dns": True,
+ "limited": True,
+ "ports": [
+ "1194",
+ "443",
+ "53",
+ "80"],
+ "protocols": [
+ "tcp",
+ "udp"],
+ "transport": ["openvpn"],
+ "user_ips": False},
+ "host": "host.dev.example.org",
+ "ip_address": "11.22.33.44",
+ "location": "cyberspace"
+ }, {
+ "capabilities": {
+ "adblock": False,
+ "filter_dns": True,
+ "limited": True,
+ "ports": [
+ "1194",
+ "443",
+ "53",
+ "80"],
+ "protocols": [
+ "tcp",
+ "udp"],
+ "transport": ["openvpn"],
+ "user_ips": False},
+ "host": "host2.dev.example.org",
+ "ip_address": "22.33.44.55",
+ "location": "cyberspace"
+ }
+ ],
+ "locations": {
+ "ankara": {
+ "country_code": "XX",
+ "hemisphere": "S",
+ "name": "Antarctica",
+ "timezone": "+2"
+ },
+ "cyberspace": {
+ "country_code": "XX",
+ "hemisphere": "X",
+ "name": "outer space",
+ "timezone": ""
+ }
+ },
+ "openvpn_configuration": {
+ "auth": "SHA1",
+ "cipher": "AES-128-CBC",
+ "tls-cipher": "DHE-RSA-AES128-SHA"
+ },
+ "serial": 1,
+ "version": 1
+}
+
+
+class EIPConfigTest(BaseLeapTest):
+
+ __name__ = "eip_config_tests"
+
+ maxDiff = None
+
+ def setUp(self):
+ self._old_ospath_exists = os.path.exists
+
+ def tearDown(self):
+ os.path.exists = self._old_ospath_exists
+
+ def _write_config(self, data):
+ """
+ Helper to write some data to a temp config file.
+
+ :param data: data to be used to save in the config file.
+ :data type: dict (valid json)
+ """
+ self.configfile = os.path.join(self.tempdir, "eipconfig.json")
+ conf = open(self.configfile, "w")
+ conf.write(json.dumps(data))
+ conf.close()
+
+ def _get_eipconfig(self, fromfile=True, data=sample_config):
+ """
+ Helper that returns an EIPConfig object using the data parameter
+ or a sample data.
+
+ :param fromfile: sets if we should use a file or a string
+ :fromfile type: bool
+ :param data: sets the data to be used to load in the EIPConfig object
+ :data type: dict (valid json)
+ :rtype: EIPConfig
+ """
+ config = EIPConfig()
+
+ loaded = False
+ if fromfile:
+ self._write_config(data)
+ loaded = config.load(self.configfile, relative=False)
+ else:
+ json_string = json.dumps(data)
+ loaded = config.load(data=json_string)
+
+ if not loaded:
+ return None
+
+ return config
+
+ def test_loads_from_file(self):
+ config = self._get_eipconfig()
+ self.assertIsNotNone(config)
+
+ def test_loads_from_data(self):
+ config = self._get_eipconfig(fromfile=False)
+ self.assertIsNotNone(config)
+
+ def test_load_valid_config_from_file(self):
+ config = self._get_eipconfig()
+ self.assertIsNotNone(config)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ sample_config["openvpn_configuration"])
+
+ sample_ip = sample_config["gateways"][0]["ip_address"]
+ self.assertEqual(
+ config.get_gateway_ip(),
+ sample_ip)
+ self.assertEqual(config.get_version(), sample_config["version"])
+ self.assertEqual(config.get_serial(), sample_config["serial"])
+ self.assertEqual(config.get_gateways(), sample_config["gateways"])
+ self.assertEqual(config.get_locations(), sample_config["locations"])
+ self.assertEqual(config.get_clusters(), None)
+
+ def test_load_valid_config_from_data(self):
+ config = self._get_eipconfig(fromfile=False)
+ self.assertIsNotNone(config)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ sample_config["openvpn_configuration"])
+
+ sample_ip = sample_config["gateways"][0]["ip_address"]
+ self.assertEqual(
+ config.get_gateway_ip(),
+ sample_ip)
+
+ self.assertEqual(config.get_version(), sample_config["version"])
+ self.assertEqual(config.get_serial(), sample_config["serial"])
+ self.assertEqual(config.get_gateways(), sample_config["gateways"])
+ self.assertEqual(config.get_locations(), sample_config["locations"])
+ self.assertEqual(config.get_clusters(), None)
+
+ def test_sanitize_extra_parameters(self):
+ data = copy.deepcopy(sample_config)
+ data['openvpn_configuration']["extra_param"] = "FOO"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ sample_config["openvpn_configuration"])
+
+ def test_sanitize_non_allowed_chars(self):
+ data = copy.deepcopy(sample_config)
+ data['openvpn_configuration']["auth"] = "SHA1;"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ sample_config["openvpn_configuration"])
+
+ data = copy.deepcopy(sample_config)
+ data['openvpn_configuration']["auth"] = "SHA1>`&|"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ sample_config["openvpn_configuration"])
+
+ def test_sanitize_lowercase(self):
+ data = copy.deepcopy(sample_config)
+ data['openvpn_configuration']["auth"] = "shaSHA1"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ sample_config["openvpn_configuration"])
+
+ def test_all_characters_invalid(self):
+ data = copy.deepcopy(sample_config)
+ data['openvpn_configuration']["auth"] = "sha&*!@#;"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(
+ config.get_openvpn_configuration(),
+ {'cipher': 'AES-128-CBC',
+ 'tls-cipher': 'DHE-RSA-AES128-SHA'})
+
+ def test_sanitize_bad_ip(self):
+ data = copy.deepcopy(sample_config)
+ data['gateways'][0]["ip_address"] = "11.22.33.44;"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(config.get_gateway_ip(), None)
+
+ data = copy.deepcopy(sample_config)
+ data['gateways'][0]["ip_address"] = "11.22.33.44`"
+ config = self._get_eipconfig(data=data)
+
+ self.assertEqual(config.get_gateway_ip(), None)
+
+ def test_default_gateway_on_unknown_index(self):
+ config = self._get_eipconfig()
+ sample_ip = sample_config["gateways"][0]["ip_address"]
+ self.assertEqual(config.get_gateway_ip(999), sample_ip)
+
+ def test_get_gateway_by_index(self):
+ config = self._get_eipconfig()
+ sample_ip_0 = sample_config["gateways"][0]["ip_address"]
+ sample_ip_1 = sample_config["gateways"][1]["ip_address"]
+ self.assertEqual(config.get_gateway_ip(0), sample_ip_0)
+ self.assertEqual(config.get_gateway_ip(1), sample_ip_1)
+
+ def test_get_client_cert_path_as_expected(self):
+ config = self._get_eipconfig()
+ config.get_path_prefix = Mock(return_value='test')
+
+ provider_config = ProviderConfig()
+
+ # mock 'get_domain' so we don't need to load a config
+ provider_domain = 'test.provider.com'
+ provider_config.get_domain = Mock(return_value=provider_domain)
+
+ expected_path = os.path.join('test', 'leap', 'providers',
+ provider_domain, 'keys', 'client',
+ 'openvpn.pem')
+
+ # mock 'os.path.exists' so we don't get an error for unexisting file
+ os.path.exists = Mock(return_value=True)
+ cert_path = config.get_client_cert_path(provider_config)
+
+ self.assertEqual(cert_path, expected_path)
+
+ def test_get_client_cert_path_about_to_download(self):
+ config = self._get_eipconfig()
+ config.get_path_prefix = Mock(return_value='test')
+
+ provider_config = ProviderConfig()
+
+ # mock 'get_domain' so we don't need to load a config
+ provider_domain = 'test.provider.com'
+ provider_config.get_domain = Mock(return_value=provider_domain)
+
+ expected_path = os.path.join('test', 'leap', 'providers',
+ provider_domain, 'keys', 'client',
+ 'openvpn.pem')
+
+ cert_path = config.get_client_cert_path(
+ provider_config, about_to_download=True)
+
+ self.assertEqual(cert_path, expected_path)
+
+ def test_get_client_cert_path_fails(self):
+ config = self._get_eipconfig()
+ provider_config = ProviderConfig()
+
+ # mock 'get_domain' so we don't need to load a config
+ provider_domain = 'test.provider.com'
+ provider_config.get_domain = Mock(return_value=provider_domain)
+
+ with self.assertRaises(AssertionError):
+ config.get_client_cert_path(provider_config)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/services/eip/tests/test_providerbootstrapper.py b/src/leap/services/eip/tests/test_providerbootstrapper.py
new file mode 100644
index 00000000..cd740793
--- /dev/null
+++ b/src/leap/services/eip/tests/test_providerbootstrapper.py
@@ -0,0 +1,504 @@
+# -*- coding: utf-8 -*-
+# test_providerbootstrapper.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+"""
+Tests for the Provider Boostrapper checks
+
+These will be whitebox tests since we want to make sure the private
+implementation is checking what we expect.
+"""
+
+import os
+import mock
+import socket
+import stat
+import tempfile
+import time
+import requests
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from nose.twistedtools import deferred, reactor
+from twisted.internet import threads
+from requests.models import Response
+
+from leap.common.testing.https_server import where
+from leap.common.testing.basetest import BaseLeapTest
+from leap.services.eip.providerbootstrapper import ProviderBootstrapper
+from leap.services.eip.providerbootstrapper import UnsupportedProviderAPI
+from leap.provider.supportedapis import SupportedAPIs
+from leap.config.providerconfig import ProviderConfig
+from leap.crypto.tests import fake_provider
+from leap.common.files import mkdir_p
+
+
+class ProviderBootstrapperTest(BaseLeapTest):
+ def setUp(self):
+ self.pb = ProviderBootstrapper()
+
+ def tearDown(self):
+ pass
+
+ def test_name_resolution_check(self):
+ # Something highly likely to success
+ self.pb._domain = "google.com"
+ self.pb._check_name_resolution()
+ # Something highly likely to fail
+ self.pb._domain = "uquhqweuihowquie.abc.def"
+ with self.assertRaises(socket.gaierror):
+ self.pb._check_name_resolution()
+
+ @deferred()
+ def test_run_provider_select_checks(self):
+ self.pb._check_name_resolution = mock.MagicMock()
+ self.pb._check_https = mock.MagicMock()
+ self.pb._download_provider_info = mock.MagicMock()
+
+ d = self.pb.run_provider_select_checks("somedomain")
+
+ def check(*args):
+ self.pb._check_name_resolution.assert_called_once_with()
+ self.pb._check_https.assert_called_once_with(None)
+ self.pb._download_provider_info.assert_called_once_with(None)
+ d.addCallback(check)
+ return d
+
+ @deferred()
+ def test_run_provider_setup_checks(self):
+ self.pb._download_ca_cert = mock.MagicMock()
+ self.pb._check_ca_fingerprint = mock.MagicMock()
+ self.pb._check_api_certificate = mock.MagicMock()
+
+ d = self.pb.run_provider_setup_checks(ProviderConfig())
+
+ def check(*args):
+ self.pb._download_ca_cert.assert_called_once_with()
+ self.pb._check_ca_fingerprint.assert_called_once_with(None)
+ self.pb._check_api_certificate.assert_called_once_with(None)
+ d.addCallback(check)
+ return d
+
+ def test_should_proceed_cert(self):
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=where("cacert.pem"))
+
+ self.pb._download_if_needed = False
+ self.assertTrue(self.pb._should_proceed_cert())
+
+ self.pb._download_if_needed = True
+ self.assertFalse(self.pb._should_proceed_cert())
+
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=where("somefilethatdoesntexist.pem"))
+ self.assertTrue(self.pb._should_proceed_cert())
+
+ def _check_download_ca_cert(self, should_proceed):
+ """
+ Helper to check different paths easily for the download ca
+ cert check
+
+ :param should_proceed: sets the _should_proceed_cert in the
+ provider bootstrapper being tested
+ :type should_proceed: bool
+
+ :returns: The contents of the certificate, the expected
+ content depending on should_proceed, and the mode of
+ the file to be checked by the caller
+ :rtype: tuple of str, str, int
+ """
+ old_content = "NOT THE NEW CERT"
+ new_content = "NEW CERT"
+ new_cert_path = os.path.join(tempfile.mkdtemp(),
+ "mynewcert.pem")
+
+ with open(new_cert_path, "w") as c:
+ c.write(old_content)
+
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=new_cert_path)
+ self.pb._domain = "somedomain"
+
+ self.pb._should_proceed_cert = mock.MagicMock(
+ return_value=should_proceed)
+
+ read = None
+ content_to_check = None
+ mode = None
+
+ with mock.patch('requests.models.Response.content',
+ new_callable=mock.PropertyMock) as \
+ content:
+ content.return_value = new_content
+ response_obj = Response()
+ response_obj.raise_for_status = mock.MagicMock()
+
+ self.pb._session.get = mock.MagicMock(return_value=response_obj)
+ self.pb._download_ca_cert()
+ with open(new_cert_path, "r") as nc:
+ read = nc.read()
+ if should_proceed:
+ content_to_check = new_content
+ else:
+ content_to_check = old_content
+ mode = stat.S_IMODE(os.stat(new_cert_path).st_mode)
+
+ os.unlink(new_cert_path)
+ return read, content_to_check, mode
+
+ def test_download_ca_cert_no_saving(self):
+ read, expected_read, mode = self._check_download_ca_cert(False)
+ self.assertEqual(read, expected_read)
+ self.assertEqual(mode, int("600", 8))
+
+ def test_download_ca_cert_saving(self):
+ read, expected_read, mode = self._check_download_ca_cert(True)
+ self.assertEqual(read, expected_read)
+ self.assertEqual(mode, int("600", 8))
+
+ def test_check_ca_fingerprint_skips(self):
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
+ return_value="")
+ self.pb._domain = "somedomain"
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=False)
+
+ self.pb._check_ca_fingerprint()
+ self.assertFalse(self.pb._provider_config.
+ get_ca_cert_fingerprint.called)
+
+ def test_check_ca_cert_fingerprint_raises_bad_format(self):
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
+ return_value="wrongfprformat!!")
+ self.pb._domain = "somedomain"
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ with self.assertRaises(AssertionError):
+ self.pb._check_ca_fingerprint()
+
+ # This two hashes different in the last byte, but that's good enough
+ # for the tests
+ KNOWN_BAD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034efe" \
+ "7dd1b910062ca323eb4da5c7f"
+ KNOWN_GOOD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034ef" \
+ "e7dd1b910062ca323eb4da5c7e"
+ KNOWN_GOOD_CERT = """
+-----BEGIN CERTIFICATE-----
+MIIFbzCCA1egAwIBAgIBATANBgkqhkiG9w0BAQ0FADBKMRgwFgYDVQQDDA9CaXRt
+YXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNVBAsME2h0dHBzOi8v
+Yml0bWFzay5uZXQwHhcNMTIxMTA2MDAwMDAwWhcNMjIxMTA2MDAwMDAwWjBKMRgw
+FgYDVQQDDA9CaXRtYXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNV
+BAsME2h0dHBzOi8vYml0bWFzay5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAw
+ggIKAoICAQC1eV4YvayaU+maJbWrD4OHo3d7S1BtDlcvkIRS1Fw3iYDjsyDkZxai
+dHp4EUasfNQ+EVtXUvtk6170EmLco6Elg8SJBQ27trE6nielPRPCfX3fQzETRfvB
+7tNvGw4Jn2YKiYoMD79kkjgyZjkJ2r/bEHUSevmR09BRp86syHZerdNGpXYhcQ84
+CA1+V+603GFIHnrP+uQDdssW93rgDNYu+exT+Wj6STfnUkugyjmPRPjL7wh0tzy+
+znCeLl4xiV3g9sjPnc7r2EQKd5uaTe3j71sDPF92KRk0SSUndREz+B1+Dbe/RGk4
+MEqGFuOzrtsgEhPIX0hplhb0Tgz/rtug+yTT7oJjBa3u20AAOQ38/M99EfdeJvc4
+lPFF1XBBLh6X9UKF72an2NuANiX6XPySnJgZ7nZ09RiYZqVwu/qt3DfvLfhboq+0
+bQvLUPXrVDr70onv5UDjpmEA/cLmaIqqrduuTkFZOym65/PfAPvpGnt7crQj/Ibl
+DEDYZQmP7AS+6zBjoOzNjUGE5r40zWAR1RSi7zliXTu+yfsjXUIhUAWmYR6J3KxB
+lfsiHBQ+8dn9kC3YrUexWoOqBiqJOAJzZh5Y1tqgzfh+2nmHSB2dsQRs7rDRRlyy
+YMbkpzL9ZsOUO2eTP1mmar6YjCN+rggYjRrX71K2SpBG6b1zZxOG+wIDAQABo2Aw
+XjAdBgNVHQ4EFgQUuYGDLL2sswnYpHHvProt1JU+D48wDgYDVR0PAQH/BAQDAgIE
+MAwGA1UdEwQFMAMBAf8wHwYDVR0jBBgwFoAUuYGDLL2sswnYpHHvProt1JU+D48w
+DQYJKoZIhvcNAQENBQADggIBADeG67vaFcbITGpi51264kHPYPEWaXUa5XYbtmBl
+cXYyB6hY5hv/YNuVGJ1gWsDmdeXEyj0j2icGQjYdHRfwhrbEri+h1EZOm1cSBDuY
+k/P5+ctHyOXx8IE79DBsZ6IL61UKIaKhqZBfLGYcWu17DVV6+LT+AKtHhOrv3TSj
+RnAcKnCbKqXLhUPXpK0eTjPYS2zQGQGIhIy9sQXVXJJJsGrPgMxna1Xw2JikBOCG
+htD/JKwt6xBmNwktH0GI/LVtVgSp82Clbn9C4eZN9E5YbVYjLkIEDhpByeC71QhX
+EIQ0ZR56bFuJA/CwValBqV/G9gscTPQqd+iETp8yrFpAVHOW+YzSFbxjTEkBte1J
+aF0vmbqdMAWLk+LEFPQRptZh0B88igtx6tV5oVd+p5IVRM49poLhuPNJGPvMj99l
+mlZ4+AeRUnbOOeAEuvpLJbel4rhwFzmUiGoeTVoPZyMevWcVFq6BMkS+jRR2w0jK
+G6b0v5XDHlcFYPOgUrtsOBFJVwbutLvxdk6q37kIFnWCd8L3kmES5q4wjyFK47Co
+Ja8zlx64jmMZPg/t3wWqkZgXZ14qnbyG5/lGsj5CwVtfDljrhN0oCWK1FZaUmW3d
+69db12/g4f6phldhxiWuGC/W6fCW5kre7nmhshcltqAJJuU47iX+DarBFiIj816e
+yV8e
+-----END CERTIFICATE-----
+"""
+
+ def _prepare_provider_config_with(self, cert_path, cert_hash):
+ """
+ Mocks the provider config to give the cert_path and cert_hash
+ specified
+
+ :param cert_path: path for the certificate
+ :type cert_path: str
+ :param cert_hash: hash for the certificate as it would appear
+ in the provider config json
+ :type cert_hash: str
+ """
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
+ return_value=cert_hash)
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=cert_path)
+ self.pb._domain = "somedomain"
+
+ def test_check_ca_fingerprint_checksout(self):
+ cert_path = os.path.join(tempfile.mkdtemp(),
+ "mynewcert.pem")
+
+ with open(cert_path, "w") as c:
+ c.write(self.KNOWN_GOOD_CERT)
+
+ self._prepare_provider_config_with(cert_path, self.KNOWN_GOOD_HASH)
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ self.pb._check_ca_fingerprint()
+
+ os.unlink(cert_path)
+
+ def test_check_ca_fingerprint_fails(self):
+ cert_path = os.path.join(tempfile.mkdtemp(),
+ "mynewcert.pem")
+
+ with open(cert_path, "w") as c:
+ c.write(self.KNOWN_GOOD_CERT)
+
+ self._prepare_provider_config_with(cert_path, self.KNOWN_BAD_HASH)
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ with self.assertRaises(AssertionError):
+ self.pb._check_ca_fingerprint()
+
+ os.unlink(cert_path)
+
+
+###############################################################################
+# Tests with a fake provider #
+###############################################################################
+
+class ProviderBootstrapperActiveTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ factory = fake_provider.get_provider_factory()
+ http = reactor.listenTCP(8002, factory)
+ https = reactor.listenSSL(
+ 0, factory,
+ fake_provider.OpenSSLServerContextFactory())
+ get_port = lambda p: p.getHost().port
+ cls.http_port = get_port(http)
+ cls.https_port = get_port(https)
+
+ def setUp(self):
+ self.pb = ProviderBootstrapper()
+
+ # At certain points we are going to be replacing these methods
+ # directly in ProviderConfig to be able to catch calls from
+ # new ProviderConfig objects inside the methods tested. We
+ # need to save the old implementation and restore it in
+ # tearDown so we are sure everything is as expected for each
+ # test. If we do it inside each specific test, a failure in
+ # the test will leave the implementation with the mock.
+ self.old_gpp = ProviderConfig.get_path_prefix
+ self.old_load = ProviderConfig.load
+ self.old_save = ProviderConfig.save
+ self.old_api_version = ProviderConfig.get_api_version
+
+ def tearDown(self):
+ ProviderConfig.get_path_prefix = self.old_gpp
+ ProviderConfig.load = self.old_load
+ ProviderConfig.save = self.old_save
+ ProviderConfig.get_api_version = self.old_api_version
+
+ def test_check_https_succeeds(self):
+ # XXX: Need a proper CA signed cert to test this
+ pass
+
+ @deferred()
+ def test_check_https_fails(self):
+ self.pb._domain = "localhost:%s" % (self.https_port,)
+
+ def check(*args):
+ with self.assertRaises(requests.exceptions.SSLError):
+ self.pb._check_https()
+ return threads.deferToThread(check)
+
+ @deferred()
+ def test_second_check_https_fails(self):
+ self.pb._domain = "localhost:1234"
+
+ def check(*args):
+ with self.assertRaises(Exception):
+ self.pb._check_https()
+ return threads.deferToThread(check)
+
+ @deferred()
+ def test_check_https_succeeds_if_danger(self):
+ self.pb._domain = "localhost:%s" % (self.https_port,)
+ self.pb._bypass_checks = True
+
+ def check(*args):
+ self.pb._check_https()
+
+ return threads.deferToThread(check)
+
+ def _setup_provider_config_with(self, api, path_prefix):
+ """
+ Sets up the ProviderConfig with mocks for the path prefix, the
+ api returned and load/save methods.
+ It modifies ProviderConfig directly instead of an object
+ because the object used is created in the method itself and we
+ cannot control that.
+
+ :param api: API to return
+ :type api: str
+ :param path_prefix: path prefix to be used when calculating
+ paths
+ :type path_prefix: str
+ """
+ ProviderConfig.get_path_prefix = mock.MagicMock(
+ return_value=path_prefix)
+ ProviderConfig.get_api_version = mock.MagicMock(
+ return_value=api)
+ ProviderConfig.load = mock.MagicMock()
+ ProviderConfig.save = mock.MagicMock()
+
+ def _setup_providerbootstrapper(self, ifneeded):
+ """
+ Sets the provider bootstrapper's domain to
+ localhost:https_port, sets it to bypass https checks and sets
+ the download if needed based on the ifneeded value.
+
+ :param ifneeded: Value for _download_if_needed
+ :type ifneeded: bool
+ """
+ self.pb._domain = "localhost:%s" % (self.https_port,)
+ self.pb._bypass_checks = True
+ self.pb._download_if_needed = ifneeded
+
+ def _produce_dummy_provider_json(self):
+ """
+ Creates a dummy provider json on disk in order to test
+ behaviour around it (download if newer online, etc)
+
+ :returns: the provider.json path used
+ :rtype: str
+ """
+ provider_dir = os.path.join(ProviderConfig()
+ .get_path_prefix(),
+ "leap",
+ "providers",
+ self.pb._domain)
+ mkdir_p(provider_dir)
+ provider_path = os.path.join(provider_dir,
+ "provider.json")
+
+ with open(provider_path, "w") as p:
+ p.write("A")
+ return provider_path
+
+ def test_download_provider_info_not_modified(self):
+ self._setup_provider_config_with("1", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(True)
+ provider_path = self._produce_dummy_provider_json()
+
+ # set mtime to something really new
+ os.utime(provider_path, (-1, time.time()))
+
+ self.pb._download_provider_info()
+ # we check that it doesn't do anything with the provider
+ # config, because it's new enough
+ self.assertFalse(ProviderConfig.load.called)
+ self.assertFalse(ProviderConfig.save.called)
+
+ def test_download_provider_info_modified(self):
+ self._setup_provider_config_with("1", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(True)
+ provider_path = self._produce_dummy_provider_json()
+
+ # set mtime to something really old
+ os.utime(provider_path, (-1, 100))
+
+ self.pb._download_provider_info()
+ self.assertTrue(ProviderConfig.load.called)
+ self.assertTrue(ProviderConfig.save.called)
+
+ def test_download_provider_info_unsupported_api_raises(self):
+ self._setup_provider_config_with("9999999", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(False)
+ self._produce_dummy_provider_json()
+
+ with self.assertRaises(UnsupportedProviderAPI):
+ self.pb._download_provider_info()
+
+ def test_download_provider_info_unsupported_api(self):
+ self._setup_provider_config_with(SupportedAPIs.SUPPORTED_APIS[0],
+ tempfile.mkdtemp())
+ self._setup_providerbootstrapper(False)
+ self._produce_dummy_provider_json()
+
+ self.pb._download_provider_info()
+
+ def test_check_api_certificate_skips(self):
+ self.pb._provider_config = ProviderConfig()
+ self.pb._provider_config.get_api_uri = mock.MagicMock(
+ return_value="api.uri")
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value="/cert/path")
+ self.pb._session.get = mock.MagicMock(return_value=Response())
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=False)
+ self.pb._check_api_certificate()
+ self.assertFalse(self.pb._session.get.called)
+
+ @deferred()
+ def test_check_api_certificate_fails(self):
+ self.pb._provider_config = ProviderConfig()
+ self.pb._provider_config.get_api_uri = mock.MagicMock(
+ return_value="https://localhost:%s" % (self.https_port,))
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=os.path.join(
+ os.path.split(__file__)[0],
+ "wrongcert.pem"))
+ self.pb._provider_config.get_api_version = mock.MagicMock(
+ return_value="1")
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ def check(*args):
+ with self.assertRaises(requests.exceptions.SSLError):
+ self.pb._check_api_certificate()
+ d = threads.deferToThread(check)
+ return d
+
+ @deferred()
+ def test_check_api_certificate_succeeds(self):
+ self.pb._provider_config = ProviderConfig()
+ self.pb._provider_config.get_api_uri = mock.MagicMock(
+ return_value="https://localhost:%s" % (self.https_port,))
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=where('cacert.pem'))
+ self.pb._provider_config.get_api_version = mock.MagicMock(
+ return_value="1")
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ def check(*args):
+ self.pb._check_api_certificate()
+ d = threads.deferToThread(check)
+ return d
diff --git a/src/leap/services/eip/tests/test_vpngatewayselector.py b/src/leap/services/eip/tests/test_vpngatewayselector.py
new file mode 100644
index 00000000..c90681d7
--- /dev/null
+++ b/src/leap/services/eip/tests/test_vpngatewayselector.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# test_vpngatewayselector.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+tests for vpngatewayselector
+"""
+
+import unittest
+
+from leap.services.eip.eipconfig import EIPConfig, VPNGatewaySelector
+from leap.common.testing.basetest import BaseLeapTest
+from mock import Mock
+
+
+sample_gateways = [
+ {u'host': u'gateway1.com',
+ u'ip_address': u'1.2.3.4',
+ u'location': u'location1'},
+ {u'host': u'gateway2.com',
+ u'ip_address': u'2.3.4.5',
+ u'location': u'location2'},
+ {u'host': u'gateway3.com',
+ u'ip_address': u'3.4.5.6',
+ u'location': u'location3'},
+ {u'host': u'gateway4.com',
+ u'ip_address': u'4.5.6.7',
+ u'location': u'location4'}
+]
+
+sample_gateways_no_location = [
+ {u'host': u'gateway1.com',
+ u'ip_address': u'1.2.3.4'},
+ {u'host': u'gateway2.com',
+ u'ip_address': u'2.3.4.5'},
+ {u'host': u'gateway3.com',
+ u'ip_address': u'3.4.5.6'}
+]
+
+sample_locations = {
+ u'location1': {u'timezone': u'2'},
+ u'location2': {u'timezone': u'-7'},
+ u'location3': {u'timezone': u'-4'},
+ u'location4': {u'timezone': u'+13'}
+}
+
+# 0 is not used, only for indexing from 1 in tests
+ips = (0, u'1.2.3.4', u'2.3.4.5', u'3.4.5.6', u'4.5.6.7')
+
+
+class VPNGatewaySelectorTest(BaseLeapTest):
+ """
+ VPNGatewaySelector's tests.
+ """
+ def setUp(self):
+ self.eipconfig = EIPConfig()
+ self.eipconfig.get_gateways = Mock(return_value=sample_gateways)
+ self.eipconfig.get_locations = Mock(return_value=sample_locations)
+
+ def tearDown(self):
+ pass
+
+ def test_get_no_gateways(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig)
+ self.eipconfig.get_gateways = Mock(return_value=[])
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [])
+
+ def test_get_gateway_with_no_locations(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig)
+ self.eipconfig.get_gateways = Mock(
+ return_value=sample_gateways_no_location)
+ self.eipconfig.get_locations = Mock(return_value=[])
+ gateways = gateway_selector.get_gateways()
+ gateways_default_order = [
+ sample_gateways[0]['ip_address'],
+ sample_gateways[1]['ip_address'],
+ sample_gateways[2]['ip_address']
+ ]
+ self.assertEqual(gateways, gateways_default_order)
+
+ def test_correct_order_gmt(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, 0)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[1], ips[3], ips[2], ips[4]])
+
+ def test_correct_order_gmt_minus_3(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, -3)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[3], ips[2], ips[1], ips[4]])
+
+ def test_correct_order_gmt_minus_7(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, -7)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[2], ips[3], ips[4], ips[1]])
+
+ def test_correct_order_gmt_plus_5(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, 5)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[1], ips[4], ips[3], ips[2]])
+
+ def test_correct_order_gmt_plus_12(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, 12)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[4], ips[2], ips[3], ips[1]])
+
+ def test_correct_order_gmt_minus_11(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, -11)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[4], ips[2], ips[3], ips[1]])
+
+ def test_correct_order_gmt_plus_14(self):
+ gateway_selector = VPNGatewaySelector(self.eipconfig, 14)
+ gateways = gateway_selector.get_gateways()
+ self.assertEqual(gateways, [ips[4], ips[2], ips[3], ips[1]])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/services/eip/tests/wrongcert.pem b/src/leap/services/eip/tests/wrongcert.pem
new file mode 100644
index 00000000..e6cff38a
--- /dev/null
+++ b/src/leap/services/eip/tests/wrongcert.pem
@@ -0,0 +1,33 @@
+-----BEGIN CERTIFICATE-----
+MIIFtTCCA52gAwIBAgIJAIWZus5EIXNtMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
+BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
+aWRnaXRzIFB0eSBMdGQwHhcNMTMwNjI1MTc0NjExWhcNMTgwNjI1MTc0NjExWjBF
+MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50
+ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC
+CgKCAgEA2ObM7ESjyuxFZYD/Y68qOPQgjgggW+cdXfBpU2p4n7clsrUeMhWdW40Y
+77Phzor9VOeqs3ZpHuyLzsYVp/kFDm8tKyo2ah5fJwzL0VCSLYaZkUQQ7GNUmTCk
+furaxl8cQx/fg395V7/EngsS9B3/y5iHbctbA4MnH3jaotO5EGeo6hw7/eyCotQ9
+KbBV9GJMcY94FsXBCmUB+XypKklWTLhSaS6Cu4Fo8YLW6WmcnsyEOGS2F7WVf5at
+7CBWFQZHaSgIBLmc818/mDYCnYmCVMFn/6Ndx7V2NTlz+HctWrQn0dmIOnCUeCwS
+wXq9PnBR1rSx/WxwyF/WpyjOFkcIo7vm72kS70pfrYsXcZD4BQqkXYj3FyKnPt3O
+ibLKtCxL8/83wOtErPcYpG6LgFkgAAlHQ9MkUi5dbmjCJtpqQmlZeK1RALdDPiB3
+K1KZimrGsmcE624dJxUIOJJpuwJDy21F8kh5ZAsAtE1prWETrQYNElNFjQxM83rS
+ZR1Ql2MPSB4usEZT57+KvpEzlOnAT3elgCg21XrjSFGi14hCEao4g2OEZH5GAwm5
+frf6UlSRZ/g3tLTfI8Hv1prw15W2qO+7q7SBAplTODCRk+Yb0YoA2mMM/QXBUcXs
+vKEDLSSxzNIBi3T62l39RB/ml+gPKo87ZMDivex1ZhrcJc3Yu3sCAwEAAaOBpzCB
+pDAdBgNVHQ4EFgQUPjE+4pun+8FreIdpoR8v6N7xKtUwdQYDVR0jBG4wbIAUPjE+
+4pun+8FreIdpoR8v6N7xKtWhSaRHMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpT
+b21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGSCCQCF
+mbrORCFzbTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4ICAQCpvCPdtvXJ
+muTj379TZuCJs7/l0FhA7AHa1WAlHjsXHaA7N0+3ZWAbdtXDsowal6S+ldgU/kfV
+Lq7NrRq+amJWC7SYj6cvVwhrSwSvu01fe/TWuOzHrRv1uTfJ/VXLonVufMDd9opo
+bhqYxMaxLdIx6t/MYmZH4Wpiq0yfZuv//M8i7BBl/qvaWbLhg0yVAKRwjFvf59h6
+6tRFCLddELOIhLDQtk8zMbioPEbfAlKdwwP8kYGtDGj6/9/YTd/oTKRdgHuwyup3
+m0L20Y6LddC+tb0WpK5EyrNbCbEqj1L4/U7r6f/FKNA3bx6nfdXbscaMfYonKAKg
+1cRrRg45sErmCz0QyTnWzXyvbjR4oQRzyW3kJ1JZudZ+AwOi00J5FYa3NiLuxl1u
+gIGKWSrASQWhEdpa1nlCgX7PhdaQgYjEMpQvA0GCA0OF5JDu8en1yZqsOt1hCLIN
+lkz/5jKPqrclY5hV99bE3hgCHRmIPNHCZG3wbZv2yJKxJX1YLMmQwAmSh2N7YwGG
+yXRvCxQs5ChPHyRairuf/5MZCZnSVb45ppTVuNUijsbflKRUgfj/XvfqQ22f+C9N
+Om2dmNvAiS2TOIfuP47CF2OUa5q4plUwmr+nyXQGM0SIoHNCj+MBdFfb3oxxAtI+
+SLhbnzQv5e84Doqz3YF0XW8jyR7q8GFLNA==
+-----END CERTIFICATE-----
diff --git a/src/leap/services/eip/udstelnet.py b/src/leap/services/eip/udstelnet.py
new file mode 100644
index 00000000..e6c82350
--- /dev/null
+++ b/src/leap/services/eip/udstelnet.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# udstelnet.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import os
+import socket
+import telnetlib
+
+
+class ConnectionRefusedError(Exception):
+ pass
+
+
+class MissingSocketError(Exception):
+ pass
+
+
+class UDSTelnet(telnetlib.Telnet):
+ """
+ A telnet-alike class, that can listen on unix domain sockets
+ """
+
+ def open(self, host, port=23, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+ """
+ Connect to a host. If port is 'unix', it will open a
+ connection over unix docmain sockets.
+
+ The optional second argument is the port number, which
+ defaults to the standard telnet port (23).
+ Don't try to reopen an already connected instance.
+ """
+ self.eof = 0
+ self.host = host
+ self.port = port
+ self.timeout = timeout
+
+ if self.port == "unix":
+ # unix sockets spoken
+ if not os.path.exists(self.host):
+ raise MissingSocketError()
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ self.sock.connect(self.host)
+ except socket.error:
+ raise ConnectionRefusedError()
+ else:
+ self.sock = socket.create_connection((host, port), timeout)
diff --git a/src/leap/services/eip/vpnlaunchers.py b/src/leap/services/eip/vpnlaunchers.py
new file mode 100644
index 00000000..570a7893
--- /dev/null
+++ b/src/leap/services/eip/vpnlaunchers.py
@@ -0,0 +1,774 @@
+# -*- coding: utf-8 -*-
+# vpnlaunchers.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+Platform dependant VPN launchers
+"""
+import commands
+import logging
+import getpass
+import os
+import platform
+import subprocess
+import stat
+try:
+ import grp
+except ImportError:
+ pass # ignore, probably windows
+
+from abc import ABCMeta, abstractmethod
+from functools import partial
+
+from leap.common.check import leap_assert, leap_assert_type
+from leap.common.files import which
+from leap.config.providerconfig import ProviderConfig
+from leap.services.eip.eipconfig import EIPConfig, VPNGatewaySelector
+from leap.util import first
+
+logger = logging.getLogger(__name__)
+
+
+class VPNLauncherException(Exception):
+ pass
+
+
+class OpenVPNNotFoundException(VPNLauncherException):
+ pass
+
+
+class EIPNoPolkitAuthAgentAvailable(VPNLauncherException):
+ pass
+
+
+class EIPNoPkexecAvailable(VPNLauncherException):
+ pass
+
+
+class VPNLauncher:
+ """
+ Abstract launcher class
+ """
+ __metaclass__ = ABCMeta
+
+ UPDOWN_FILES = None
+ OTHER_FILES = None
+
+ @abstractmethod
+ def get_vpn_command(self, eipconfig=None, providerconfig=None,
+ socket_host=None, socket_port=None):
+ """
+ Returns the platform dependant vpn launching command
+
+ :param eipconfig: eip configuration object
+ :type eipconfig: EIPConfig
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+ :param socket_host: either socket path (unix) or socket IP
+ :type socket_host: str
+ :param socket_port: either string "unix" if it's a unix
+ socket, or port otherwise
+ :type socket_port: str
+
+ :return: A VPN command ready to be launched
+ :rtype: list
+ """
+ return []
+
+ @abstractmethod
+ def get_vpn_env(self, providerconfig):
+ """
+ Returns a dictionary with the custom env for the platform.
+ This is mainly used for setting LD_LIBRARY_PATH to the correct
+ path when distributing a standalone client
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :rtype: dict
+ """
+ return {}
+
+ @classmethod
+ def missing_updown_scripts(kls):
+ """
+ Returns what updown scripts are missing.
+ :rtype: list
+ """
+ leap_assert(kls.UPDOWN_FILES is not None,
+ "Need to define UPDOWN_FILES for this particular "
+ "auncher before calling this method")
+ file_exist = partial(_has_updown_scripts, warn=False)
+ zipped = zip(kls.UPDOWN_FILES, map(file_exist, kls.UPDOWN_FILES))
+ missing = filter(lambda (path, exists): exists is False, zipped)
+ return [path for path, exists in missing]
+
+ @classmethod
+ def missing_other_files(kls):
+ """
+ Returns what other important files are missing during startup.
+ Same as missing_updown_scripts but does not check for exec bit.
+ :rtype: list
+ """
+ leap_assert(kls.UPDOWN_FILES is not None,
+ "Need to define OTHER_FILES for this particular "
+ "auncher before calling this method")
+ file_exist = partial(_has_other_files, warn=False)
+ zipped = zip(kls.OTHER_FILES, map(file_exist, kls.OTHER_FILES))
+ missing = filter(lambda (path, exists): exists is False, zipped)
+ return [path for path, exists in missing]
+
+
+def get_platform_launcher():
+ launcher = globals()[platform.system() + "VPNLauncher"]
+ leap_assert(launcher, "Unimplemented platform launcher: %s" %
+ (platform.system(),))
+ return launcher()
+
+
+def _is_pkexec_in_system():
+ """
+ Checks the existence of the pkexec binary in system.
+ """
+ pkexec_path = which('pkexec')
+ if len(pkexec_path) == 0:
+ return False
+ return True
+
+
+def _has_updown_scripts(path, warn=True):
+ """
+ Checks the existence of the up/down scripts and its
+ exec bit if applicable.
+
+ :param path: the path to be checked
+ :type path: str
+
+ :param warn: whether we should log the absence
+ :type warn: bool
+
+ :rtype: bool
+ """
+ is_file = os.path.isfile(path)
+ if warn and not is_file:
+ logger.error("Could not find up/down script %s. "
+ "Might produce DNS leaks." % (path,))
+
+ # XXX check if applies in win
+ is_exe = False
+ try:
+ is_exe = (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] != 0)
+ except OSError as e:
+ logger.warn("%s" % (e,))
+ if warn and not is_exe:
+ logger.error("Up/down script %s is not executable. "
+ "Might produce DNS leaks." % (path,))
+ return is_file and is_exe
+
+
+def _has_other_files(path, warn=True):
+ """
+ Checks the existence of other important files.
+
+ :param path: the path to be checked
+ :type path: str
+
+ :param warn: whether we should log the absence
+ :type warn: bool
+
+ :rtype: bool
+ """
+ is_file = os.path.isfile(path)
+ if warn and not is_file:
+ logger.warning("Could not find file during checks: %s. " % (
+ path,))
+ return is_file
+
+
+def _is_auth_agent_running():
+ """
+ Checks if a polkit daemon is running.
+
+ :return: True if it's running, False if it's not.
+ :rtype: boolean
+ """
+ ps = 'ps aux | grep polkit-%s-authentication-agent-1'
+ opts = (ps % case for case in ['[g]nome', '[k]de'])
+ is_running = map(lambda l: commands.getoutput(l), opts)
+ return any(is_running)
+
+
+def _try_to_launch_agent():
+ """
+ Tries to launch a polkit daemon.
+ """
+ opts = [
+ "/usr/lib/policykit-1-gnome/polkit-gnome-authentication-agent-1",
+ # XXX add kde thing here
+ ]
+ for cmd in opts:
+ try:
+ subprocess.Popen([cmd], shell=True)
+ except:
+ pass
+
+
+class LinuxVPNLauncher(VPNLauncher):
+ """
+ VPN launcher for the Linux platform
+ """
+
+ PKEXEC_BIN = 'pkexec'
+ OPENVPN_BIN = 'openvpn'
+ SYSTEM_CONFIG = "/etc/leap"
+ UP_DOWN_FILE = "resolv-update"
+ UP_DOWN_PATH = "%s/%s" % (SYSTEM_CONFIG, UP_DOWN_FILE)
+
+ # We assume this is there by our openvpn dependency, and
+ # we will put it there on the bundle too.
+ # TODO adapt to the bundle path.
+ OPENVPN_DOWN_ROOT_BASE = "/usr/lib/openvpn/"
+ OPENVPN_DOWN_ROOT_FILE = "openvpn-plugin-down-root.so"
+ OPENVPN_DOWN_ROOT_PATH = "%s/%s" % (
+ OPENVPN_DOWN_ROOT_BASE,
+ OPENVPN_DOWN_ROOT_FILE)
+
+ POLKIT_BASE = "/usr/share/polkit-1/actions"
+ POLKIT_FILE = "net.openvpn.gui.leap.policy"
+ POLKIT_PATH = "%s/%s" % (POLKIT_BASE, POLKIT_FILE)
+
+ UPDOWN_FILES = (UP_DOWN_PATH,)
+ OTHER_FILES = (POLKIT_PATH,)
+
+ @classmethod
+ def cmd_for_missing_scripts(kls, frompath):
+ """
+ Returns a command that can copy the missing scripts.
+ :rtype: str
+ """
+ to = kls.SYSTEM_CONFIG
+ cmd = "#!/bin/sh\nset -e\nmkdir -p %s\ncp %s/%s %s\ncp %s/%s %s" % (
+ to,
+ frompath, kls.UP_DOWN_FILE, to,
+ frompath, kls.POLKIT_FILE, kls.POLKIT_PATH)
+ return cmd
+
+ @classmethod
+ def maybe_pkexec(kls):
+ """
+ Checks whether pkexec is available in the system, and
+ returns the path if found.
+
+ Might raise EIPNoPkexecAvailable or EIPNoPolkitAuthAgentAvailable
+
+ :returns: a list of the paths where pkexec is to be found
+ :rtype: list
+ """
+ if _is_pkexec_in_system():
+ if not _is_auth_agent_running():
+ _try_to_launch_agent()
+ if _is_auth_agent_running():
+ pkexec_possibilities = which(kls.PKEXEC_BIN)
+ leap_assert(len(pkexec_possibilities) > 0,
+ "We couldn't find pkexec")
+ return pkexec_possibilities
+ else:
+ logger.warning("No polkit auth agent found. pkexec " +
+ "will use its own auth agent.")
+ raise EIPNoPolkitAuthAgentAvailable()
+ else:
+ logger.warning("System has no pkexec")
+ raise EIPNoPkexecAvailable()
+
+ @classmethod
+ def maybe_down_plugin(kls):
+ """
+ Returns the path of the openvpn down-root-plugin, searching first
+ in the relative path for the standalone bundle, and then in the system
+ path where the debian package puts it.
+
+ :returns: the path where the plugin was found, or None
+ :rtype: str or None
+ """
+ cwd = os.getcwd()
+ rel_path_in_bundle = os.path.join(
+ 'apps', 'eip', 'files', kls.OPENVPN_DOWN_ROOT_FILE)
+ abs_path_in_bundle = os.path.join(cwd, rel_path_in_bundle)
+ if os.path.isfile(abs_path_in_bundle):
+ return abs_path_in_bundle
+ abs_path_in_system = kls.OPENVPN_DOWN_ROOT_FILE
+ if os.path.isfile(abs_path_in_system):
+ return abs_path_in_system
+
+ logger.warning("We could not find the down-root-plugin, so no updown "
+ "scripts will be run. DNS leaks are likely!")
+ return None
+
+ def get_vpn_command(self, eipconfig=None, providerconfig=None,
+ socket_host=None, socket_port="unix"):
+ """
+ Returns the platform dependant vpn launching command. It will
+ look for openvpn in the regular paths and algo in
+ path_prefix/apps/eip/ (in case standalone is set)
+
+ Might raise VPNException.
+
+ :param eipconfig: eip configuration object
+ :type eipconfig: EIPConfig
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :param socket_host: either socket path (unix) or socket IP
+ :type socket_host: str
+
+ :param socket_port: either string "unix" if it's a unix
+ socket, or port otherwise
+ :type socket_port: str
+
+ :return: A VPN command ready to be launched
+ :rtype: list
+ """
+ leap_assert(eipconfig, "We need an eip config")
+ leap_assert_type(eipconfig, EIPConfig)
+ leap_assert(providerconfig, "We need a provider config")
+ leap_assert_type(providerconfig, ProviderConfig)
+ leap_assert(socket_host, "We need a socket host!")
+ leap_assert(socket_port, "We need a socket port!")
+
+ kwargs = {}
+ if ProviderConfig.standalone:
+ kwargs['path_extension'] = os.path.join(
+ providerconfig.get_path_prefix(),
+ "..", "apps", "eip")
+
+ openvpn_possibilities = which(self.OPENVPN_BIN, **kwargs)
+
+ if len(openvpn_possibilities) == 0:
+ raise OpenVPNNotFoundException()
+
+ openvpn = first(openvpn_possibilities)
+ args = []
+
+ pkexec = self.maybe_pkexec()
+ if pkexec:
+ args.append(openvpn)
+ openvpn = first(pkexec)
+
+ # TODO: handle verbosity
+
+ gateway_selector = VPNGatewaySelector(eipconfig)
+ gateways = gateway_selector.get_gateways()
+
+ logger.debug("Using gateways ips: {}".format(', '.join(gateways)))
+
+ for gw in gateways:
+ args += ['--remote', gw, '1194', 'udp']
+
+ args += [
+ '--client',
+ '--dev', 'tun',
+ '--persist-tun',
+ '--persist-key',
+ '--tls-client',
+ '--remote-cert-tls',
+ 'server'
+ ]
+
+ openvpn_configuration = eipconfig.get_openvpn_configuration()
+
+ for key, value in openvpn_configuration.items():
+ args += ['--%s' % (key,), value]
+
+ args += [
+ '--user', getpass.getuser(),
+ '--group', grp.getgrgid(os.getgroups()[-1]).gr_name
+ ]
+
+ if socket_port == "unix": # that's always the case for linux
+ args += [
+ '--management-client-user', getpass.getuser()
+ ]
+
+ args += [
+ '--management-signal',
+ '--management', socket_host, socket_port,
+ '--script-security', '2'
+ ]
+
+ plugin_path = self.maybe_down_plugin()
+ # If we do not have the down plugin neither in the bundle
+ # nor in the system, we do not do updown scripts. The alternative
+ # is leaving the user without the ability to restore dns and routes
+ # to its original state.
+
+ if plugin_path and _has_updown_scripts(self.UP_DOWN_PATH):
+ args += [
+ '--up', self.UP_DOWN_PATH,
+ '--down', self.UP_DOWN_PATH,
+ '--plugin', plugin_path,
+ '\'script_type=down %s\'' % self.UP_DOWN_PATH
+ ]
+
+ args += [
+ '--cert', eipconfig.get_client_cert_path(providerconfig),
+ '--key', eipconfig.get_client_cert_path(providerconfig),
+ '--ca', providerconfig.get_ca_cert_path()
+ ]
+
+ logger.debug("Running VPN with command:")
+ logger.debug("%s %s" % (openvpn, " ".join(args)))
+
+ return [openvpn] + args
+
+ def get_vpn_env(self, providerconfig):
+ """
+ Returns a dictionary with the custom env for the platform.
+ This is mainly used for setting LD_LIBRARY_PATH to the correct
+ path when distributing a standalone client
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :rtype: dict
+ """
+ leap_assert(providerconfig, "We need a provider config")
+ leap_assert_type(providerconfig, ProviderConfig)
+
+ return {"LD_LIBRARY_PATH": os.path.join(
+ providerconfig.get_path_prefix(),
+ "..", "lib")}
+
+
+class DarwinVPNLauncher(VPNLauncher):
+ """
+ VPN launcher for the Darwin Platform
+ """
+
+ COCOASUDO = "cocoasudo"
+ # XXX need magic translate for this string
+ SUDO_MSG = ("LEAP needs administrative privileges to run "
+ "Encrypted Internet.")
+
+ INSTALL_PATH = "/Applications/LEAP\ Client.app"
+ OPENVPN_BIN = 'openvpn.leap'
+ OPENVPN_PATH = "%s/Contents/Resources/openvpn" % (INSTALL_PATH,)
+
+ UP_SCRIPT = "%s/client.up.sh" % (OPENVPN_PATH,)
+ DOWN_SCRIPT = "%s/client.down.sh" % (OPENVPN_PATH,)
+ OPENVPN_DOWN_PLUGIN = '%s/openvpn-down-root.so' % (OPENVPN_PATH,)
+
+ UPDOWN_FILES = (UP_SCRIPT, DOWN_SCRIPT, OPENVPN_DOWN_PLUGIN)
+
+ @classmethod
+ def cmd_for_missing_scripts(kls, frompath):
+ """
+ Returns a command that can copy the missing scripts.
+ :rtype: str
+ """
+ to = kls.OPENVPN_PATH
+ cmd = "#!/bin/sh\nmkdir -p %s\ncp \"%s/\"* %s" % (to, frompath, to)
+ return cmd
+
+ def get_cocoasudo_cmd(self):
+ """
+ Returns a string with the cocoasudo command needed to run openvpn
+ as admin with a nice password prompt. The actual command needs to be
+ appended.
+
+ :rtype: (str, list)
+ """
+ iconpath = os.path.abspath(os.path.join(
+ os.getcwd(),
+ "../../../Resources/leap-client.tiff"))
+ has_icon = os.path.isfile(iconpath)
+ args = ["--icon=%s" % iconpath] if has_icon else []
+ args.append("--prompt=%s" % (self.SUDO_MSG,))
+
+ return self.COCOASUDO, args
+
+ def get_vpn_command(self, eipconfig=None, providerconfig=None,
+ socket_host=None, socket_port="unix"):
+ """
+ Returns the platform dependant vpn launching command
+
+ Might raise VPNException.
+
+ :param eipconfig: eip configuration object
+ :type eipconfig: EIPConfig
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :param socket_host: either socket path (unix) or socket IP
+ :type socket_host: str
+
+ :param socket_port: either string "unix" if it's a unix
+ socket, or port otherwise
+ :type socket_port: str
+
+ :return: A VPN command ready to be launched
+ :rtype: list
+ """
+ leap_assert(eipconfig, "We need an eip config")
+ leap_assert_type(eipconfig, EIPConfig)
+ leap_assert(providerconfig, "We need a provider config")
+ leap_assert_type(providerconfig, ProviderConfig)
+ leap_assert(socket_host, "We need a socket host!")
+ leap_assert(socket_port, "We need a socket port!")
+
+ kwargs = {}
+ if ProviderConfig.standalone:
+ kwargs['path_extension'] = os.path.join(
+ providerconfig.get_path_prefix(),
+ "..", "apps", "eip")
+
+ openvpn_possibilities = which(
+ self.OPENVPN_BIN,
+ **kwargs)
+ if len(openvpn_possibilities) == 0:
+ raise OpenVPNNotFoundException()
+
+ openvpn = first(openvpn_possibilities)
+ args = [openvpn]
+
+ # TODO: handle verbosity
+
+ gateway_selector = VPNGatewaySelector(eipconfig)
+ gateways = gateway_selector.get_gateways()
+
+ logger.debug("Using gateways ips: {gw}".format(
+ gw=', '.join(gateways)))
+
+ for gw in gateways:
+ args += ['--remote', gw, '1194', 'udp']
+
+ args += [
+ '--client',
+ '--dev', 'tun',
+ '--persist-tun',
+ '--persist-key',
+ '--tls-client',
+ '--remote-cert-tls',
+ 'server'
+ ]
+
+ openvpn_configuration = eipconfig.get_openvpn_configuration()
+ for key, value in openvpn_configuration.items():
+ args += ['--%s' % (key,), value]
+
+ user = getpass.getuser()
+ args += [
+ '--user', user,
+ '--group', grp.getgrgid(os.getgroups()[-1]).gr_name
+ ]
+
+ if socket_port == "unix":
+ args += [
+ '--management-client-user', user
+ ]
+
+ args += [
+ '--management-signal',
+ '--management', socket_host, socket_port,
+ '--script-security', '2'
+ ]
+
+ if _has_updown_scripts(self.UP_SCRIPT):
+ args += [
+ '--up', self.UP_SCRIPT,
+ ]
+
+ if _has_updown_scripts(self.DOWN_SCRIPT):
+ args += [
+ '--down', self.DOWN_SCRIPT]
+
+ # should have the down script too
+ if _has_updown_scripts(self.OPENVPN_DOWN_PLUGIN):
+ args += [
+ '--plugin', self.OPENVPN_DOWN_PLUGIN,
+ '\'%s\'' % self.DOWN_SCRIPT
+ ]
+
+ # we set user to be passed to the up/down scripts
+ args += [
+ '--setenv', "LEAPUSER", "%s" % (user,)]
+
+ args += [
+ '--cert', eipconfig.get_client_cert_path(providerconfig),
+ '--key', eipconfig.get_client_cert_path(providerconfig),
+ '--ca', providerconfig.get_ca_cert_path()
+ ]
+
+ command, cargs = self.get_cocoasudo_cmd()
+ cmd_args = cargs + args
+
+ logger.debug("Running VPN with command:")
+ logger.debug("%s %s" % (command, " ".join(cmd_args)))
+
+ return [command] + cmd_args
+
+ def get_vpn_env(self, providerconfig):
+ """
+ Returns a dictionary with the custom env for the platform.
+ This is mainly used for setting LD_LIBRARY_PATH to the correct
+ path when distributing a standalone client
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :rtype: dict
+ """
+ return {"DYLD_LIBRARY_PATH": os.path.join(
+ providerconfig.get_path_prefix(),
+ "..", "lib")}
+
+
+class WindowsVPNLauncher(VPNLauncher):
+ """
+ VPN launcher for the Windows platform
+ """
+
+ OPENVPN_BIN = 'openvpn_leap.exe'
+
+ # XXX UPDOWN_FILES ... we do not have updown files defined yet!
+
+ def get_vpn_command(self, eipconfig=None, providerconfig=None,
+ socket_host=None, socket_port="9876"):
+ """
+ Returns the platform dependant vpn launching command. It will
+ look for openvpn in the regular paths and algo in
+ path_prefix/apps/eip/ (in case standalone is set)
+
+ Might raise VPNException.
+
+ :param eipconfig: eip configuration object
+ :type eipconfig: EIPConfig
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+ :param socket_host: either socket path (unix) or socket IP
+ :type socket_host: str
+ :param socket_port: either string "unix" if it's a unix
+ socket, or port otherwise
+ :type socket_port: str
+
+ :return: A VPN command ready to be launched
+ :rtype: list
+ """
+ leap_assert(eipconfig, "We need an eip config")
+ leap_assert_type(eipconfig, EIPConfig)
+ leap_assert(providerconfig, "We need a provider config")
+ leap_assert_type(providerconfig, ProviderConfig)
+ leap_assert(socket_host, "We need a socket host!")
+ leap_assert(socket_port, "We need a socket port!")
+ leap_assert(socket_port != "unix",
+ "We cannot use unix sockets in windows!")
+
+ openvpn_possibilities = which(
+ self.OPENVPN_BIN,
+ path_extension=os.path.join(providerconfig.get_path_prefix(),
+ "..", "apps", "eip"))
+
+ if len(openvpn_possibilities) == 0:
+ raise OpenVPNNotFoundException()
+
+ openvpn = first(openvpn_possibilities)
+ args = []
+
+ # TODO: handle verbosity
+
+ gateway_selector = VPNGatewaySelector(eipconfig)
+ gateways = gateway_selector.get_gateways()
+
+ logger.debug("Using gateways ips: {}".format(', '.join(gateways)))
+
+ for gw in gateways:
+ args += ['--remote', gw, '1194', 'udp']
+
+ args += [
+ '--client',
+ '--dev', 'tun',
+ '--persist-tun',
+ '--persist-key',
+ '--tls-client',
+ '--remote-cert-tls',
+ 'server'
+ ]
+
+ openvpn_configuration = eipconfig.get_openvpn_configuration()
+ for key, value in openvpn_configuration.items():
+ args += ['--%s' % (key,), value]
+
+ args += [
+ '--user', getpass.getuser(),
+ #'--group', grp.getgrgid(os.getgroups()[-1]).gr_name
+ ]
+ args += [
+ '--management-signal',
+ '--management', socket_host, socket_port,
+ '--script-security', '2'
+ ]
+ args += [
+ '--cert', eipconfig.get_client_cert_path(providerconfig),
+ '--key', eipconfig.get_client_cert_path(providerconfig),
+ '--ca', providerconfig.get_ca_cert_path()
+ ]
+
+ logger.debug("Running VPN with command:")
+ logger.debug("%s %s" % (openvpn, " ".join(args)))
+
+ return [openvpn] + args
+
+ def get_vpn_env(self, providerconfig):
+ """
+ Returns a dictionary with the custom env for the platform.
+ This is mainly used for setting LD_LIBRARY_PATH to the correct
+ path when distributing a standalone client
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :rtype: dict
+ """
+ return {}
+
+
+if __name__ == "__main__":
+ logger = logging.getLogger(name='leap')
+ logger.setLevel(logging.DEBUG)
+ console = logging.StreamHandler()
+ console.setLevel(logging.DEBUG)
+ formatter = logging.Formatter(
+ '%(asctime)s '
+ '- %(name)s - %(levelname)s - %(message)s')
+ console.setFormatter(formatter)
+ logger.addHandler(console)
+
+ try:
+ abs_launcher = VPNLauncher()
+ except Exception as e:
+ assert isinstance(e, TypeError), "Something went wrong"
+ print "Abstract Prefixer class is working as expected"
+
+ vpnlauncher = get_platform_launcher()
+
+ eipconfig = EIPConfig()
+ if eipconfig.load("leap/providers/bitmask.net/eip-service.json"):
+ provider = ProviderConfig()
+ if provider.load("leap/providers/bitmask.net/provider.json"):
+ vpnlauncher.get_vpn_command(eipconfig=eipconfig,
+ providerconfig=provider,
+ socket_host="/blah")
diff --git a/src/leap/services/eip/vpnprocess.py b/src/leap/services/eip/vpnprocess.py
new file mode 100644
index 00000000..0ec56ae7
--- /dev/null
+++ b/src/leap/services/eip/vpnprocess.py
@@ -0,0 +1,718 @@
+# -*- coding: utf-8 -*-
+# vpnprocess.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+VPN Manager, spawned in a custom processProtocol.
+"""
+import logging
+import os
+import psutil
+import psutil.error
+import shutil
+import socket
+
+from PySide import QtCore
+
+from leap.common.check import leap_assert, leap_assert_type
+from leap.config.providerconfig import ProviderConfig
+from leap.services.eip.vpnlaunchers import get_platform_launcher
+from leap.services.eip.eipconfig import EIPConfig
+from leap.services.eip.udstelnet import UDSTelnet
+
+logger = logging.getLogger(__name__)
+vpnlog = logging.getLogger('leap.openvpn')
+
+from twisted.internet import protocol
+from twisted.internet import defer
+from twisted.internet.task import LoopingCall
+from twisted.internet import error as internet_error
+
+
+class VPNSignals(QtCore.QObject):
+ """
+ These are the signals that we use to let the UI know
+ about the events we are polling.
+ They are instantiated in the VPN object and passed along
+ till the VPNProcess.
+ """
+ state_changed = QtCore.Signal(dict)
+ status_changed = QtCore.Signal(dict)
+ process_finished = QtCore.Signal(int)
+
+ def __init__(self):
+ QtCore.QObject.__init__(self)
+
+
+class VPN(object):
+ """
+ This is the high-level object that the GUI is dealing with.
+ It exposes the start and terminate methods.
+
+ On start, it spawns a VPNProcess instance that will use a vpnlauncher
+ suited for the running platform and connect to the management interface
+ opened by the openvpn process, executing commands over that interface on
+ demand.
+ """
+ TERMINATE_MAXTRIES = 10
+ TERMINATE_WAIT = 1 # secs
+
+ def __init__(self):
+ """
+ Instantiate empty attributes and get a copy
+ of a QObject containing the QSignals that we will pass along
+ to the VPNManager.
+ """
+ from twisted.internet import reactor
+ self._vpnproc = None
+ self._pollers = []
+ self._reactor = reactor
+ self._qtsigs = VPNSignals()
+
+ @property
+ def qtsigs(self):
+ return self._qtsigs
+
+ def start(self, *args, **kwargs):
+ """
+ Starts the openvpn subprocess.
+
+ :param args: args to be passed to the VPNProcess
+ :type args: tuple
+
+ :param kwargs: kwargs to be passed to the VPNProcess
+ :type kwargs: dict
+ """
+ kwargs['qtsigs'] = self.qtsigs
+
+ # start the main vpn subprocess
+ vpnproc = VPNProcess(*args, **kwargs)
+
+ # XXX Should stop if already running -------
+ if vpnproc.get_openvpn_process():
+ logger.warning("Another vpnprocess is running!")
+
+ cmd = vpnproc.getCommand()
+ env = os.environ
+ for key, val in vpnproc.vpn_env.items():
+ env[key] = val
+
+ self._reactor.spawnProcess(vpnproc, cmd[0], cmd, env)
+ self._vpnproc = vpnproc
+
+ # add pollers for status and state
+ # this could be extended to a collection of
+ # generic watchers
+
+ poll_list = [LoopingCall(vpnproc.pollStatus),
+ LoopingCall(vpnproc.pollState)]
+ self._pollers.extend(poll_list)
+ self._start_pollers()
+
+ def _kill_if_left_alive(self, tries=0):
+ """
+ Check if the process is still alive, and sends a
+ SIGKILL after a timeout period.
+
+ :param tries: counter of tries, used in recursion
+ :type tries: int
+ """
+ from twisted.internet import reactor
+ while tries < self.TERMINATE_MAXTRIES:
+ if self._vpnproc.transport.pid is None:
+ logger.debug("Process has been happily terminated.")
+ return
+ else:
+ logger.debug("Process did not die, waiting...")
+ tries += 1
+ reactor.callLater(self.TERMINATE_WAIT,
+ self._kill_if_left_alive, tries)
+
+ # after running out of patience, we try a killProcess
+ logger.debug("Process did not died. Sending a SIGKILL.")
+ self.killit()
+
+ def killit(self):
+ """
+ Sends a kill signal to the process.
+ """
+ self._stop_pollers()
+ self._vpnproc.aborted = True
+ self._vpnproc.killProcess()
+
+ def terminate(self, shutdown=False):
+ """
+ Stops the openvpn subprocess.
+
+ Attempts to send a SIGTERM first, and after a timeout
+ it sends a SIGKILL.
+ """
+ from twisted.internet import reactor
+ self._stop_pollers()
+
+ # First we try to be polite and send a SIGTERM...
+ if self._vpnproc:
+ self._sentterm = True
+ self._vpnproc.terminate_openvpn(shutdown=shutdown)
+
+ # ...but we also trigger a countdown to be unpolite
+ # if strictly needed.
+ reactor.callLater(
+ self.TERMINATE_WAIT, self._kill_if_left_alive)
+
+ def _start_pollers(self):
+ """
+ Iterate through the registered observers
+ and start the looping call for them.
+ """
+ for poller in self._pollers:
+ poller.start(VPNManager.POLL_TIME)
+
+ def _stop_pollers(self):
+ """
+ Iterate through the registered observers
+ and stop the looping calls if they are running.
+ """
+ for poller in self._pollers:
+ if poller.running:
+ poller.stop()
+ self._pollers = []
+
+
+class VPNManager(object):
+ """
+ This is a mixin that we use in the VPNProcess class.
+ Here we get together all methods related with the openvpn management
+ interface.
+
+ A copy of a QObject containing signals as attributes is passed along
+ upon initialization, and we use that object to emit signals to qt-land.
+
+ For more info about management methods::
+
+ zcat `dpkg -L openvpn | grep management`
+ """
+
+ # Timers, in secs
+ POLL_TIME = 0.5
+ CONNECTION_RETRY_TIME = 1
+
+ TS_KEY = "ts"
+ STATUS_STEP_KEY = "status_step"
+ OK_KEY = "ok"
+ IP_KEY = "ip"
+ REMOTE_KEY = "remote"
+
+ TUNTAP_READ_KEY = "tun_tap_read"
+ TUNTAP_WRITE_KEY = "tun_tap_write"
+ TCPUDP_READ_KEY = "tcp_udp_read"
+ TCPUDP_WRITE_KEY = "tcp_udp_write"
+ AUTH_READ_KEY = "auth_read"
+
+ def __init__(self, qtsigs=None):
+ """
+ Initializes the VPNManager.
+
+ :param qtsigs: a QObject containing the Qt signals used by the UI
+ to give feedback about state changes.
+ :type qtsigs: QObject
+ """
+ from twisted.internet import reactor
+ self._reactor = reactor
+ self._tn = None
+ self._qtsigs = qtsigs
+ self._aborted = False
+
+ @property
+ def qtsigs(self):
+ return self._qtsigs
+
+ @property
+ def aborted(self):
+ return self._aborted
+
+ @aborted.setter
+ def aborted(self, value):
+ self._aborted = value
+
+ def _seek_to_eof(self):
+ """
+ Read as much as available. Position seek pointer to end of stream
+ """
+ try:
+ self._tn.read_eager()
+ except EOFError:
+ logger.debug("Could not read from socket. Assuming it died.")
+ return
+
+ def _send_command(self, command, until=b"END"):
+ """
+ Sends a command to the telnet connection and reads until END
+ is reached.
+
+ :param command: command to send
+ :type command: str
+
+ :param until: byte delimiter string for reading command output
+ :type until: byte str
+
+ :return: response read
+ :rtype: list
+ """
+ leap_assert(self._tn, "We need a tn connection!")
+
+ try:
+ self._tn.write("%s\n" % (command,))
+ buf = self._tn.read_until(until, 2)
+ self._seek_to_eof()
+ blist = buf.split('\r\n')
+ if blist[-1].startswith(until):
+ del blist[-1]
+ return blist
+ else:
+ return []
+
+ except socket.error:
+ # XXX should get a counter and repeat only
+ # after mod X times.
+ logger.warning('socket error')
+ self._close_management_socket(announce=False)
+ return []
+
+ # XXX should move this to a errBack!
+ except Exception as e:
+ logger.warning("Error sending command %s: %r" %
+ (command, e))
+ return []
+
+ def _close_management_socket(self, announce=True):
+ """
+ Close connection to openvpn management interface.
+ """
+ logger.debug('closing socket')
+ if announce:
+ self._tn.write("quit\n")
+ self._tn.read_all()
+ self._tn.get_socket().close()
+ self._tn = None
+
+ def _connect_management(self, socket_host, socket_port):
+ """
+ Connects to the management interface on the specified
+ socket_host socket_port.
+
+ :param socket_host: either socket path (unix) or socket IP
+ :type socket_host: str
+
+ :param socket_port: either string "unix" if it's a unix
+ socket, or port otherwise
+ :type socket_port: str
+ """
+ if self.is_connected():
+ self._close_management_socket()
+
+ try:
+ self._tn = UDSTelnet(socket_host, socket_port)
+
+ # XXX make password optional
+ # specially for win. we should generate
+ # the pass on the fly when invoking manager
+ # from conductor
+
+ # self.tn.read_until('ENTER PASSWORD:', 2)
+ # self.tn.write(self.password + '\n')
+ # self.tn.read_until('SUCCESS:', 2)
+ if self._tn:
+ self._tn.read_eager()
+
+ # XXX move this to the Errback
+ except Exception as e:
+ logger.warning("Could not connect to OpenVPN yet: %r" % (e,))
+ self._tn = None
+
+ def _connectCb(self, *args):
+ """
+ Callback for connection.
+
+ :param args: not used
+ """
+ if self._tn:
+ logger.info('connected to management')
+
+ def _connectErr(self, failure):
+ """
+ Errorback for connection.
+
+ :param failure: Failure
+ """
+ logger.warning(failure)
+
+ def connect_to_management(self, host, port):
+ """
+ Connect to a management interface.
+
+ :param host: the host of the management interface
+ :type host: str
+
+ :param port: the port of the management interface
+ :type port: str
+
+ :returns: a deferred
+ """
+ self.connectd = defer.maybeDeferred(
+ self._connect_management, host, port)
+ self.connectd.addCallbacks(self._connectCb, self._connectErr)
+ return self.connectd
+
+ def is_connected(self):
+ """
+ Returns the status of the management interface.
+
+ :returns: True if connected, False otherwise
+ :rtype: bool
+ """
+ return True if self._tn else False
+
+ def try_to_connect_to_management(self, retry=0):
+ """
+ Attempts to connect to a management interface, and retries
+ after CONNECTION_RETRY_TIME if not successful.
+
+ :param retry: number of the retry
+ :type retry: int
+ """
+ # TODO decide about putting a max_lim to retries and signaling
+ # an error.
+ if not self.aborted and not self.is_connected():
+ self.connect_to_management(self._socket_host, self._socket_port)
+ self._reactor.callLater(
+ self.CONNECTION_RETRY_TIME,
+ self.try_to_connect_to_management, retry + 1)
+
+ def _parse_state_and_notify(self, output):
+ """
+ Parses the output of the state command and emits state_changed
+ signal when the state changes.
+
+ :param output: list of lines that the state command printed as
+ its output
+ :type output: list
+ """
+ for line in output:
+ stripped = line.strip()
+ if stripped == "END":
+ continue
+ parts = stripped.split(",")
+ if len(parts) < 5:
+ continue
+ ts, status_step, ok, ip, remote = parts
+
+ state_dict = {
+ self.TS_KEY: ts,
+ self.STATUS_STEP_KEY: status_step,
+ self.OK_KEY: ok,
+ self.IP_KEY: ip,
+ self.REMOTE_KEY: remote
+ }
+
+ if state_dict != self._last_state:
+ self.qtsigs.state_changed.emit(state_dict)
+ self._last_state = state_dict
+
+ def _parse_status_and_notify(self, output):
+ """
+ Parses the output of the status command and emits
+ status_changed signal when the status changes.
+
+ :param output: list of lines that the status command printed
+ as its output
+ :type output: list
+ """
+ tun_tap_read = ""
+ tun_tap_write = ""
+ tcp_udp_read = ""
+ tcp_udp_write = ""
+ auth_read = ""
+ for line in output:
+ stripped = line.strip()
+ if stripped.endswith("STATISTICS") or stripped == "END":
+ continue
+ parts = stripped.split(",")
+ if len(parts) < 2:
+ continue
+ if parts[0].strip() == "TUN/TAP read bytes":
+ tun_tap_read = parts[1]
+ elif parts[0].strip() == "TUN/TAP write bytes":
+ tun_tap_write = parts[1]
+ elif parts[0].strip() == "TCP/UDP read bytes":
+ tcp_udp_read = parts[1]
+ elif parts[0].strip() == "TCP/UDP write bytes":
+ tcp_udp_write = parts[1]
+ elif parts[0].strip() == "Auth read bytes":
+ auth_read = parts[1]
+
+ status_dict = {
+ self.TUNTAP_READ_KEY: tun_tap_read,
+ self.TUNTAP_WRITE_KEY: tun_tap_write,
+ self.TCPUDP_READ_KEY: tcp_udp_read,
+ self.TCPUDP_WRITE_KEY: tcp_udp_write,
+ self.AUTH_READ_KEY: auth_read
+ }
+
+ if status_dict != self._last_status:
+ self.qtsigs.status_changed.emit(status_dict)
+ self._last_status = status_dict
+
+ def get_state(self):
+ """
+ Notifies the gui of the output of the state command over
+ the openvpn management interface.
+ """
+ if self.is_connected():
+ return self._parse_state_and_notify(self._send_command("state"))
+
+ def get_status(self):
+ """
+ Notifies the gui of the output of the status command over
+ the openvpn management interface.
+ """
+ if self.is_connected():
+ return self._parse_status_and_notify(self._send_command("status"))
+
+ @property
+ def vpn_env(self):
+ """
+ Return a dict containing the vpn environment to be used.
+ """
+ return self._launcher.get_vpn_env(self._providerconfig)
+
+ def terminate_openvpn(self, shutdown=False):
+ """
+ Attempts to terminate openvpn by sending a SIGTERM.
+ """
+ if self.is_connected():
+ self._send_command("signal SIGTERM")
+ if shutdown:
+ self._cleanup_tempfiles()
+
+ def _cleanup_tempfiles(self):
+ """
+ Remove all temporal files we might have left behind.
+
+ Iif self.port is 'unix', we have created a temporal socket path that,
+ under normal circumstances, we should be able to delete.
+ """
+ if self._socket_port == "unix":
+ logger.debug('cleaning socket file temp folder')
+ tempfolder = os.path.split(self._socket_host)[0] # XXX use `first`
+ if os.path.isdir(tempfolder):
+ try:
+ shutil.rmtree(tempfolder)
+ except OSError:
+ logger.error('could not delete tmpfolder %s' % tempfolder)
+
+ # ---------------------------------------------------
+ # XXX old methods, not adapted to twisted process yet
+
+ def get_openvpn_process(self):
+ """
+ Looks for openvpn instances running.
+
+ :rtype: process
+ """
+ openvpn_process = None
+ for p in psutil.process_iter():
+ try:
+ # XXX Not exact!
+ # Will give false positives.
+ # we should check that cmdline BEGINS
+ # with openvpn or with our wrapper
+ # (pkexec / osascript / whatever)
+ if "openvpn" in ' '.join(p.cmdline):
+ openvpn_process = p
+ break
+ except psutil.error.AccessDenied:
+ pass
+ return openvpn_process
+
+ def _stop_if_already_running(self):
+ """
+ Checks if VPN is already running and tries to stop it.
+
+ :return: True if stopped, False otherwise
+ """
+ # TODO cleanup this
+ process = self._get_openvpn_process()
+ if process:
+ logger.debug("OpenVPN is already running, trying to stop it...")
+ cmdline = process.cmdline
+
+ manag_flag = "--management"
+ if isinstance(cmdline, list) and manag_flag in cmdline:
+ try:
+ index = cmdline.index(manag_flag)
+ host = cmdline[index + 1]
+ port = cmdline[index + 2]
+ logger.debug("Trying to connect to %s:%s"
+ % (host, port))
+ self._connect_to_management(host, port)
+ self._send_command("signal SIGTERM")
+ self._tn.close()
+ self._tn = None
+ #self._disconnect_management()
+ except Exception as e:
+ logger.warning("Problem trying to terminate OpenVPN: %r"
+ % (e,))
+
+ process = self._get_openvpn_process()
+ if process is None:
+ logger.warning("Unabled to terminate OpenVPN")
+ return True
+ else:
+ return False
+ return True
+
+
+class VPNProcess(protocol.ProcessProtocol, VPNManager):
+ """
+ A ProcessProtocol class that can be used to spawn a process that will
+ launch openvpn and connect to its management interface to control it
+ programmatically.
+ """
+
+ def __init__(self, eipconfig, providerconfig, socket_host, socket_port,
+ qtsigs):
+ """
+ :param eipconfig: eip configuration object
+ :type eipconfig: EIPConfig
+
+ :param providerconfig: provider specific configuration
+ :type providerconfig: ProviderConfig
+
+ :param socket_host: either socket path (unix) or socket IP
+ :type socket_host: str
+
+ :param socket_port: either string "unix" if it's a unix
+ socket, or port otherwise
+ :type socket_port: str
+
+ :param qtsigs: a QObject containing the Qt signals used to notify the
+ UI.
+ :type qtsigs: QObject
+ """
+ VPNManager.__init__(self, qtsigs=qtsigs)
+ leap_assert_type(eipconfig, EIPConfig)
+ leap_assert_type(providerconfig, ProviderConfig)
+ leap_assert_type(qtsigs, QtCore.QObject)
+
+ #leap_assert(not self.isRunning(), "Starting process more than once!")
+
+ self._eipconfig = eipconfig
+ self._providerconfig = providerconfig
+ self._socket_host = socket_host
+ self._socket_port = socket_port
+
+ self._launcher = get_platform_launcher()
+
+ self._last_state = None
+ self._last_status = None
+ self._alive = False
+
+ # processProtocol methods
+
+ def connectionMade(self):
+ """
+ Called when the connection is made.
+
+ .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa
+ """
+ self._alive = True
+ self.aborted = False
+ self.try_to_connect_to_management()
+
+ def outReceived(self, data):
+ """
+ Called when new data is available on stdout.
+
+ :param data: the data read on stdout
+
+ .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa
+ """
+ # truncate the newline
+ # should send this to the logging window
+ vpnlog.info(data[:-1])
+
+ def processExited(self, reason):
+ """
+ Called when the child process exits.
+
+ .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa
+ """
+ exit_code = reason.value.exitCode
+ if isinstance(exit_code, int):
+ logger.debug("processExited, status %d" % (exit_code,))
+ self.qtsigs.process_finished.emit(exit_code)
+ self._alive = False
+
+ def processEnded(self, reason):
+ """
+ Called when the child process exits and all file descriptors associated
+ with it have been closed.
+
+ .. seeAlso: `http://twistedmatrix.com/documents/13.0.0/api/twisted.internet.protocol.ProcessProtocol.html` # noqa
+ """
+ exit_code = reason.value.exitCode
+ if isinstance(exit_code, int):
+ logger.debug("processEnded, status %d" % (exit_code,))
+
+ # polling
+
+ def pollStatus(self):
+ """
+ Polls connection status.
+ """
+ if self._alive:
+ self.get_status()
+
+ def pollState(self):
+ """
+ Polls connection state.
+ """
+ if self._alive:
+ self.get_state()
+
+ # launcher
+
+ def getCommand(self):
+ """
+ Gets the vpn command from the aproppriate launcher.
+ """
+ cmd = self._launcher.get_vpn_command(
+ eipconfig=self._eipconfig,
+ providerconfig=self._providerconfig,
+ socket_host=self._socket_host,
+ socket_port=self._socket_port)
+ return map(str, cmd)
+
+ # shutdown
+
+ def killProcess(self):
+ """
+ Sends the KILL signal to the running process.
+ """
+ try:
+ self.transport.signalProcess('KILL')
+ except internet_error.ProcessExitedAlready:
+ logger.debug('Process Exited Already')