summaryrefslogtreecommitdiff
path: root/pkg/vpn/bonafide
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/vpn/bonafide')
-rw-r--r--pkg/vpn/bonafide/bonafide.go2
-rw-r--r--pkg/vpn/bonafide/gateways.go12
2 files changed, 9 insertions, 5 deletions
diff --git a/pkg/vpn/bonafide/bonafide.go b/pkg/vpn/bonafide/bonafide.go
index 710ece5..fc1bc95 100644
--- a/pkg/vpn/bonafide/bonafide.go
+++ b/pkg/vpn/bonafide/bonafide.go
@@ -225,7 +225,7 @@ func (b *Bonafide) GetGateways(transport string) ([]Gateway, error) {
}
// GetAllGateways only filters gateways by transport.
-// TODO could pass "any" instead?
+// if "any" is provided it will return all gateways for all transports
func (b *Bonafide) GetAllGateways(transport string) ([]Gateway, error) {
err := b.maybeInitializeEIP()
if err != nil {
diff --git a/pkg/vpn/bonafide/gateways.go b/pkg/vpn/bonafide/gateways.go
index 460d80f..50359e9 100644
--- a/pkg/vpn/bonafide/gateways.go
+++ b/pkg/vpn/bonafide/gateways.go
@@ -59,6 +59,10 @@ type gatewayPool struct {
locations map[string]Location
}
+func (gw Gateway) isTransport(transport string) bool {
+ return transport == "any" || gw.Transport == transport
+}
+
func (p *gatewayPool) populateLocationList() {
for i, gw := range p.available {
p.byLocation[gw.Location] = append(p.byLocation[gw.Location], &p.available[i])
@@ -118,7 +122,7 @@ func (p *gatewayPool) getRandomGatewaysByLocation(location, transport string) ([
var gateways []Gateway
for _, gw := range gws {
- if gw.Transport == transport {
+ if gw.isTransport(transport) {
gateways = append(gateways, *gw)
}
if len(gateways) == maxGateways {
@@ -143,7 +147,7 @@ func (p *gatewayPool) getGatewaysFromMenshenByLocation(location, transport strin
var gateways []Gateway
for _, gw := range p.recommended {
- if gw.gateway.Transport != transport {
+ if !gw.gateway.isTransport(transport) {
continue
}
for _, locatedGw := range gws {
@@ -275,7 +279,7 @@ func (p *gatewayPool) getAll(transport string, tz int) ([]Gateway, error) {
func (p *gatewayPool) getGatewaysFromMenshen(transport string, max int) ([]Gateway, error) {
gws := make([]Gateway, 0)
for _, gw := range p.recommended {
- if gw.gateway.Transport != transport {
+ if !gw.gateway.isTransport(transport) {
continue
}
gws = append(gws, *gw.gateway)
@@ -292,7 +296,7 @@ func (p *gatewayPool) getGatewaysByTimezone(transport string, tzOffsetHours, max
gwVector := []gatewayDistance{}
for _, gw := range p.available {
- if gw.Transport != transport {
+ if !gw.isTransport(transport) {
continue
}
distance := 13