summaryrefslogtreecommitdiff
path: root/src/leap/services/eip/eipconfig.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/services/eip/eipconfig.py')
-rw-r--r--src/leap/services/eip/eipconfig.py38
1 files changed, 35 insertions, 3 deletions
diff --git a/src/leap/services/eip/eipconfig.py b/src/leap/services/eip/eipconfig.py
index 4e74687a..0a7d2b23 100644
--- a/src/leap/services/eip/eipconfig.py
+++ b/src/leap/services/eip/eipconfig.py
@@ -18,8 +18,11 @@
"""
Provider configuration
"""
-import os
import logging
+import os
+import re
+
+import ipaddr
from leap.common.check import leap_assert, leap_assert_type
from leap.common.config.baseconfig import BaseConfig
@@ -33,6 +36,8 @@ class EIPConfig(BaseConfig):
"""
Provider configuration abstraction class
"""
+ OPENVPN_ALLOWED_KEYS = ("auth", "cipher", "tls-cipher")
+ OPENVPN_CIPHERS_REGEX = re.compile("[A-Z0-9\-]+")
def __init__(self):
BaseConfig.__init__(self)
@@ -52,7 +57,24 @@ class EIPConfig(BaseConfig):
return self._safe_get_value("gateways")
def get_openvpn_configuration(self):
- return self._safe_get_value("openvpn_configuration")
+ """
+ Returns a dictionary containing the openvpn configuration
+ parameters.
+
+ These are sanitized with alphanumeric whitelist.
+
+ @returns: openvpn configuration dict
+ @rtype: C{dict}
+ """
+ ovpncfg = self._safe_get_value("openvpn_configuration")
+ config = {}
+ for key, value in ovpncfg.items():
+ if key in self.OPENVPN_ALLOWED_KEYS and value is not None:
+ sanitized_val = self.OPENVPN_CIPHERS_REGEX.findall(value)
+ if len(sanitized_val) != 0:
+ _val = sanitized_val[0]
+ config[str(key)] = str(_val)
+ return config
def get_serial(self):
return self._safe_get_value("serial")
@@ -61,13 +83,23 @@ class EIPConfig(BaseConfig):
return self._safe_get_value("version")
def get_gateway_ip(self, index=0):
+ """
+ Returns the ip of the gateway
+ """
gateways = self.get_gateways()
leap_assert(len(gateways) > 0, "We don't have any gateway!")
if index > len(gateways):
index = 0
logger.warning("Provided an unknown gateway index %s, " +
"defaulting to 0")
- return gateways[0]["ip_address"]
+ ip_addr_str = gateways[0]["ip_address"]
+
+ try:
+ ipaddr.IPAddress(ip_addr_str)
+ return ip_addr_str
+ except ValueError:
+ logger.error("Invalid ip address in config: %s" % (ip_addr_str,))
+ return None
def get_client_cert_path(self,
providerconfig=None,