From 663c87c221c42e081b5947e298bc9f0541e6913a Mon Sep 17 00:00:00 2001
From: Ruben Pollan <meskio@sindominio.net>
Date: Wed, 30 Aug 2017 19:54:37 +0200
Subject: [feat] list vpn gateways in the order that they are going to be used

---
 src/leap/bitmask/cli/vpn.py      | 14 +++++++-------
 src/leap/bitmask/vpn/gateways.py | 40 +++++++++++++++++-----------------------
 src/leap/bitmask/vpn/service.py  | 36 ++++++++++++++++++++----------------
 3 files changed, 44 insertions(+), 46 deletions(-)

(limited to 'src')

diff --git a/src/leap/bitmask/cli/vpn.py b/src/leap/bitmask/cli/vpn.py
index 5a9cce9e..bcc2a448 100644
--- a/src/leap/bitmask/cli/vpn.py
+++ b/src/leap/bitmask/cli/vpn.py
@@ -125,10 +125,10 @@ def location_printer(result):
               value + Fore.RESET)
 
     for provider, locations in result.items():
-        for loc in locations.values():
-            location_str = ("[%(country_code)s] %(name)s "
-                            "(UTC%(timezone)s %(hemisphere)s)" % loc)
-            pprint(provider, location_str)
-
-        if not locations.values():
-            pprint(provider, "---")
+        for loc in locations:
+            if 'name' not in loc:
+                pprint(provider, "---")
+            else:
+                location_str = ("[%(country_code)s] %(name)s "
+                                "(UTC%(timezone)s %(hemisphere)s)" % loc)
+                pprint(provider, location_str)
diff --git a/src/leap/bitmask/vpn/gateways.py b/src/leap/bitmask/vpn/gateways.py
index a1be1c15..ac75d9a7 100644
--- a/src/leap/bitmask/vpn/gateways.py
+++ b/src/leap/bitmask/vpn/gateways.py
@@ -77,7 +77,8 @@ class GatewaySelector(object):
         """
         Returns the IPs top 4 preferred gateways, in order.
         """
-        gateways = [gateway[1] for gateway in self.get_sorted_gateways()][:4]
+        gateways = [gateway['ip_address']
+                    for gateway in self.get_sorted_gateways()][:4]
         return gateways
 
     def get_sorted_gateways(self):
@@ -92,28 +93,21 @@ class GatewaySelector(object):
             distance = 99  # if hasn't location -> should go last
             location = locations.get(gateway.get('location'))
 
-            label = gateway.get('location', 'Unknown')
-            country = 'XX'
+            gateway = gateway.copy()
             if location is not None:
-                country = location.get('country_code', 'XX')
-                label = location.get('name', label)
+                gateway.update(location)
                 timezone = location.get('timezone')
                 if timezone is not None:
                     offset = int(timezone)
                     if offset in self.equivalent_timezones:
                         offset = self.equivalent_timezones[offset]
                     distance = self._get_timezone_distance(offset)
-            ip = self.gateways[idx].get('ip_address')
-            gateways_timezones.append((ip, distance, label, country))
+            gateway['distance'] = distance
+            gateways_timezones.append(gateway)
 
-        gateways_timezones = sorted(gateways_timezones, key=lambda gw: gw[1])
-
-        result = []
-        for ip, distance, label, country in gateways_timezones:
-            result.append((label, ip, country))
-
-        filtered = self.apply_user_preferences(result)
-        return filtered
+        gateways_timezones = sorted(gateways_timezones,
+                                    key=lambda gw: gw['distance'])
+        return self.apply_user_preferences(gateways_timezones)
 
     def apply_user_preferences(self, options):
         """
@@ -125,17 +119,17 @@ class GatewaySelector(object):
         applied = []
         presorted = copy.copy(options)
         for location in self.preferred.get('loc', []):
-            for index, data in enumerate(presorted):
-                label, ip, country = data
-                if _normalized(label) == _normalized(location):
-                    applied.append((label, ip, country))
+            for index, gw in enumerate(presorted):
+                if ('location' in gw and
+                        _normalized(gw['location']) == _normalized(location)):
+                    applied.append(gw)
                     presorted.pop(index)
 
         for cc in self.preferred.get('cc', []):
-            for index, data in enumerate(presorted):
-                label, ip, country = data
-                if _normalized(country) == _normalized(cc):
-                    applied.append((label, ip, country))
+            for index, gw in enumerate(presorted):
+                if ('country_code' in gw and
+                        _normalized(gw['country_code']) == _normalized(cc)):
+                    applied.append(gw)
                     presorted.pop(index)
         if presorted:
             applied += presorted
diff --git a/src/leap/bitmask/vpn/service.py b/src/leap/bitmask/vpn/service.py
index f4af3036..6588e1d5 100644
--- a/src/leap/bitmask/vpn/service.py
+++ b/src/leap/bitmask/vpn/service.py
@@ -194,7 +194,8 @@ class VPNService(HookableService):
                 config = yield bonafide.do_provider_read(provider, 'eip')
             except ValueError:
                 continue
-            provider_dict[provider] = config.locations
+            gateways = self._gateways(config)
+            provider_dict[provider] = gateways.get_sorted_gateways()
         defer.returnValue(provider_dict)
 
     @defer.inlineCallbacks
@@ -207,21 +208,7 @@ class VPNService(HookableService):
         bonafide = self.parent.getServiceNamed('bonafide')
         config = yield bonafide.do_provider_read(provider, 'eip')
 
-        try:
-            _cco = self.parent.get_config('vpn_prefs', 'countries', "")
-            pref_cco = json.loads(_cco)
-        except ValueError:
-            pref_cco = []
-        try:
-            _loc = self.parent.get_config('vpn_prefs', 'locations', "")
-            pref_loc = json.loads(_loc)
-        except ValueError:
-            pref_loc = []
-
-        sorted_gateways = GatewaySelector(
-            config.gateways, config.locations,
-            preferred={'cc': pref_cco, 'loc': pref_loc}
-        ).select_gateways()
+        sorted_gateways = self._gateways(config).select_gateways()
 
         extra_flags = config.openvpn_configuration
 
@@ -244,6 +231,23 @@ class VPNService(HookableService):
             provider, remotes, cert_path, key_path, ca_path, extra_flags)
         self._firewall = FirewallManager(remotes)
 
+    def _gateways(self, config):
+        try:
+            _cco = self.parent.get_config('vpn_prefs', 'countries', "")
+            pref_cco = json.loads(_cco)
+        except ValueError:
+            pref_cco = []
+        try:
+            _loc = self.parent.get_config('vpn_prefs', 'locations', "")
+            pref_loc = json.loads(_loc)
+        except ValueError:
+            pref_loc = []
+
+        return GatewaySelector(
+            config.gateways, config.locations,
+            preferred={'cc': pref_cco, 'loc': pref_loc}
+        )
+
     def _cert_expires(self, provider):
         path = os.path.join(
             self._basepath, "leap", "providers", provider,
-- 
cgit v1.2.3