summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuben Pollan <meskio@sindominio.net>2017-08-30 19:54:37 +0200
committerRuben Pollan <meskio@sindominio.net>2017-08-31 12:56:43 +0200
commit663c87c221c42e081b5947e298bc9f0541e6913a (patch)
treed933acad5df919614e600aa1ae2e737ebcc6e472
parentf19bfeb73fc19747bd02cbbd5c024de4dc86b5a8 (diff)
[feat] list vpn gateways in the order that they are going to be used
-rw-r--r--src/leap/bitmask/cli/vpn.py14
-rw-r--r--src/leap/bitmask/vpn/gateways.py40
-rw-r--r--src/leap/bitmask/vpn/service.py36
-rw-r--r--tests/unit/vpn/test_gateways.py17
4 files changed, 55 insertions, 52 deletions
diff --git a/src/leap/bitmask/cli/vpn.py b/src/leap/bitmask/cli/vpn.py
index 5a9cce9..bcc2a44 100644
--- a/src/leap/bitmask/cli/vpn.py
+++ b/src/leap/bitmask/cli/vpn.py
@@ -125,10 +125,10 @@ def location_printer(result):
value + Fore.RESET)
for provider, locations in result.items():
- for loc in locations.values():
- location_str = ("[%(country_code)s] %(name)s "
- "(UTC%(timezone)s %(hemisphere)s)" % loc)
- pprint(provider, location_str)
-
- if not locations.values():
- pprint(provider, "---")
+ for loc in locations:
+ if 'name' not in loc:
+ pprint(provider, "---")
+ else:
+ location_str = ("[%(country_code)s] %(name)s "
+ "(UTC%(timezone)s %(hemisphere)s)" % loc)
+ pprint(provider, location_str)
diff --git a/src/leap/bitmask/vpn/gateways.py b/src/leap/bitmask/vpn/gateways.py
index a1be1c1..ac75d9a 100644
--- a/src/leap/bitmask/vpn/gateways.py
+++ b/src/leap/bitmask/vpn/gateways.py
@@ -77,7 +77,8 @@ class GatewaySelector(object):
"""
Returns the IPs top 4 preferred gateways, in order.
"""
- gateways = [gateway[1] for gateway in self.get_sorted_gateways()][:4]
+ gateways = [gateway['ip_address']
+ for gateway in self.get_sorted_gateways()][:4]
return gateways
def get_sorted_gateways(self):
@@ -92,28 +93,21 @@ class GatewaySelector(object):
distance = 99 # if hasn't location -> should go last
location = locations.get(gateway.get('location'))
- label = gateway.get('location', 'Unknown')
- country = 'XX'
+ gateway = gateway.copy()
if location is not None:
- country = location.get('country_code', 'XX')
- label = location.get('name', label)
+ gateway.update(location)
timezone = location.get('timezone')
if timezone is not None:
offset = int(timezone)
if offset in self.equivalent_timezones:
offset = self.equivalent_timezones[offset]
distance = self._get_timezone_distance(offset)
- ip = self.gateways[idx].get('ip_address')
- gateways_timezones.append((ip, distance, label, country))
+ gateway['distance'] = distance
+ gateways_timezones.append(gateway)
- gateways_timezones = sorted(gateways_timezones, key=lambda gw: gw[1])
-
- result = []
- for ip, distance, label, country in gateways_timezones:
- result.append((label, ip, country))
-
- filtered = self.apply_user_preferences(result)
- return filtered
+ gateways_timezones = sorted(gateways_timezones,
+ key=lambda gw: gw['distance'])
+ return self.apply_user_preferences(gateways_timezones)
def apply_user_preferences(self, options):
"""
@@ -125,17 +119,17 @@ class GatewaySelector(object):
applied = []
presorted = copy.copy(options)
for location in self.preferred.get('loc', []):
- for index, data in enumerate(presorted):
- label, ip, country = data
- if _normalized(label) == _normalized(location):
- applied.append((label, ip, country))
+ for index, gw in enumerate(presorted):
+ if ('location' in gw and
+ _normalized(gw['location']) == _normalized(location)):
+ applied.append(gw)
presorted.pop(index)
for cc in self.preferred.get('cc', []):
- for index, data in enumerate(presorted):
- label, ip, country = data
- if _normalized(country) == _normalized(cc):
- applied.append((label, ip, country))
+ for index, gw in enumerate(presorted):
+ if ('country_code' in gw and
+ _normalized(gw['country_code']) == _normalized(cc)):
+ applied.append(gw)
presorted.pop(index)
if presorted:
applied += presorted
diff --git a/src/leap/bitmask/vpn/service.py b/src/leap/bitmask/vpn/service.py
index f4af303..6588e1d 100644
--- a/src/leap/bitmask/vpn/service.py
+++ b/src/leap/bitmask/vpn/service.py
@@ -194,7 +194,8 @@ class VPNService(HookableService):
config = yield bonafide.do_provider_read(provider, 'eip')
except ValueError:
continue
- provider_dict[provider] = config.locations
+ gateways = self._gateways(config)
+ provider_dict[provider] = gateways.get_sorted_gateways()
defer.returnValue(provider_dict)
@defer.inlineCallbacks
@@ -207,21 +208,7 @@ class VPNService(HookableService):
bonafide = self.parent.getServiceNamed('bonafide')
config = yield bonafide.do_provider_read(provider, 'eip')
- try:
- _cco = self.parent.get_config('vpn_prefs', 'countries', "")
- pref_cco = json.loads(_cco)
- except ValueError:
- pref_cco = []
- try:
- _loc = self.parent.get_config('vpn_prefs', 'locations', "")
- pref_loc = json.loads(_loc)
- except ValueError:
- pref_loc = []
-
- sorted_gateways = GatewaySelector(
- config.gateways, config.locations,
- preferred={'cc': pref_cco, 'loc': pref_loc}
- ).select_gateways()
+ sorted_gateways = self._gateways(config).select_gateways()
extra_flags = config.openvpn_configuration
@@ -244,6 +231,23 @@ class VPNService(HookableService):
provider, remotes, cert_path, key_path, ca_path, extra_flags)
self._firewall = FirewallManager(remotes)
+ def _gateways(self, config):
+ try:
+ _cco = self.parent.get_config('vpn_prefs', 'countries', "")
+ pref_cco = json.loads(_cco)
+ except ValueError:
+ pref_cco = []
+ try:
+ _loc = self.parent.get_config('vpn_prefs', 'locations', "")
+ pref_loc = json.loads(_loc)
+ except ValueError:
+ pref_loc = []
+
+ return GatewaySelector(
+ config.gateways, config.locations,
+ preferred={'cc': pref_cco, 'loc': pref_loc}
+ )
+
def _cert_expires(self, provider):
path = os.path.join(
self._basepath, "leap", "providers", provider,
diff --git a/tests/unit/vpn/test_gateways.py b/tests/unit/vpn/test_gateways.py
index d16e910..eda2a22 100644
--- a/tests/unit/vpn/test_gateways.py
+++ b/tests/unit/vpn/test_gateways.py
@@ -128,6 +128,11 @@ class GatewaySelectorTestCase(unittest.TestCase):
assert gateways == [ips[4], ips[2], ips[3], ips[1]]
def test_apply_user_preferences(self):
+ def to_gateways(gws):
+ return [{'location': x[0], 'ip_address': x[1],
+ 'country_code': x[2]}
+ for x in gws]
+
preferred = {
'loc': ['anarres', 'paris__fr', 'montevideo'],
'cc': ['BR', 'AR', 'UY'],
@@ -138,8 +143,8 @@ class GatewaySelectorTestCase(unittest.TestCase):
('Rio de Janeiro', '1.1.1.1', 'BR'),
('Montevideo', '1.1.1.1', 'UY'),
('Cordoba', '1.1.1.1', 'AR')]
- ordered = selector.apply_user_preferences(pre)
- locations = [x[0] for x in ordered]
+ ordered = selector.apply_user_preferences(to_gateways(pre))
+ locations = [x['location'] for x in ordered]
# first the preferred location, then order by country
assert locations == [
'Montevideo',
@@ -152,8 +157,8 @@ class GatewaySelectorTestCase(unittest.TestCase):
('Montevideo', '', ''),
('Paris, FR', '', ''),
('AnaRreS', '', '')]
- ordered = selector.apply_user_preferences(pre)
- locations = [x[0] for x in ordered]
+ ordered = selector.apply_user_preferences(to_gateways(pre))
+ locations = [x['location'] for x in ordered]
# first the preferred location, then order by country
# (test normalization)
assert locations == [
@@ -167,8 +172,8 @@ class GatewaySelectorTestCase(unittest.TestCase):
('Tacuarembo', '', 'UY'),
('Sao Paulo', '', 'BR'),
('Cordoba', '', 'AR')]
- ordered = selector.apply_user_preferences(pre)
- locations = [x[0] for x in ordered]
+ ordered = selector.apply_user_preferences(to_gateways(pre))
+ locations = [x['location'] for x in ordered]
# no matching location, order by country
assert locations == [
'Rio De Janeiro',