package bonafide import ( "errors" "log" "math/rand" "sort" "strconv" "time" ) const ( maxGateways = 3 ) // Load reflects the fullness metric that menshen returns, if available. type Load struct { gateway *Gateway Fullness float64 Overload bool } // A Gateway is a representation of gateways that is independent of the api version. // If a given physical location offers different transports, they will appear // as separate gateways, so make sure to filter them. type Gateway struct { Host string IPAddress string Location string LocationName string CountryCode string Ports []string Protocols []string Options map[string]string Transport string } /* gatewayDistance is used in the timezone distance fallback */ type gatewayDistance struct { gateway Gateway distance int } type gatewayPool struct { /* available is the unordered list of gateways from eip-service, we use if as source-of-truth for now. */ available []Gateway userChoice string /* byLocation is a map from location to an array of hostnames */ byLocation map[string][]*Gateway /* recommended is an array of hostnames, fetched from the old geoip service. */ recommended []Load /* TODO locations are just used to get the timezone for each gateway. I * think it's easier to just merge that info into the version-agnostic * Gateway, that is passed from the eipService, and do not worry with * the location here */ locations map[string]Location } func (gw Gateway) isTransport(transport string) bool { return transport == "any" || gw.Transport == transport } func (p *gatewayPool) populateLocationList() { for i, gw := range p.available { p.byLocation[gw.Location] = append(p.byLocation[gw.Location], &p.available[i]) } } func (p *gatewayPool) getLocations() []string { c := make([]string, 0) if p == nil || p.byLocation == nil || len(p.byLocation) == 0 { return c } if len(p.byLocation) != 0 { for city := range p.byLocation { c = append(c, city) } } return c } func (p *gatewayPool) isValidLocation(location string) bool { locations := p.getLocations() valid := stringInSlice(location, locations) return valid } /* returns a map of location: fullness for the ui to use */ func (p *gatewayPool) listLocationFullness(transport string) map[string]float64 { locations := p.getLocations() cm := make(map[string]float64) if len(locations) == 0 { return cm } if len(p.recommended) != 0 { for _, gw := range p.recommended { if _, ok := cm[gw.gateway.Location]; ok { continue } cm[gw.gateway.Location] = gw.Fullness } } else { for _, location := range locations { cm[location] = -1 } } return cm } /* this method should only be used if we have no usable menshen list. */ func (p *gatewayPool) getRandomGatewaysByLocation(location, transport string) ([]Gateway, error) { if !p.isValidLocation(location) { return []Gateway{}, errors.New("bonafide: BUG not a valid location: " + location) } gws := p.byLocation[location] if len(gws) == 0 { return []Gateway{}, errors.New("bonafide: BUG no gw for location: " + location) } r := rand.New(rand.NewSource(time.Now().Unix())) r.Shuffle(len(gws), func(i, j int) { gws[i], gws[j] = gws[j], gws[i] }) var gateways []Gateway for _, gw := range gws { if gw.isTransport(transport) { gateways = append(gateways, *gw) } if len(gateways) == maxGateways { break } } if len(gateways) == 0 { return []Gateway{}, errors.New("bonafide: BUG could not find any gateway for that location") } return gateways, nil } func (p *gatewayPool) getGatewaysFromMenshenByLocation(location, transport string) ([]Gateway, error) { if !p.isValidLocation(location) { return []Gateway{}, errors.New("bonafide: BUG not a valid location: " + location) } gws := p.byLocation[location] if len(gws) == 0 { return []Gateway{}, errors.New("bonafide: BUG no gw for location: " + location) } var gateways []Gateway for _, gw := range p.recommended { for _, locatedGw := range gws { if !locatedGw.isTransport(transport) { continue } if locatedGw.Host == gw.gateway.Host { gateways = append(gateways, *locatedGw) break } } if len(gateways) == maxGateways { break } } if len(gateways) == 0 { return []Gateway{}, errors.New("bonafide: BUG could not find any gateway for that location") } return gateways, nil } /* used when we select a hostname in the ui and we want to know the gateway details */ func (p *gatewayPool) getGatewayByHost(host string) (Gateway, error) { for _, gw := range p.available { if gw.Host == host { return gw, nil } } return Gateway{}, errors.New("bonafide: not a valid host name") } /* used when we want to know gateway details after we know what IP openvpn has connected to */ func (p *gatewayPool) getGatewayByIP(ip string) (Gateway, error) { for _, gw := range p.available { if gw.IPAddress == ip { return gw, nil } } return Gateway{}, errors.New("bonafide: not a valid ip address") } /* this perhaps could be made more explicit */ func (p *gatewayPool) setAutomaticChoice() { p.userChoice = "" } /* set a user manual override for gateway location */ func (p *gatewayPool) setUserChoice(location string) error { if !p.isValidLocation(location) { return errors.New("bonafide: not a valid city for gateway choice") } p.userChoice = location return nil } func (p *gatewayPool) isManualLocation() bool { return len(p.userChoice) != 0 } /* set the recommended field from an ordered array. needs to be modified if menshen passed an array of Loads */ func (p *gatewayPool) setRecommendedGateways(geo *geoLocation) { var recommended []Load if len(geo.SortedGateways) != 0 { for _, gw := range geo.SortedGateways { found := false for i := range p.available { if p.available[i].Host == gw.Host { recommendedGw := Load{ Fullness: gw.Fullness, Overload: gw.Overload, gateway: &p.available[i], } recommended = append(recommended, recommendedGw) found = true } } if !found { log.Println("ERROR: invalid host in recommended list of hostnames", gw.Host) return } } } else { // If there is not sorted gatways, it means that the old menshen API is being used // let's use the list of hosts then for _, host := range geo.Gateways { found := false for i := range p.available { if p.available[i].Host == host { recommendedGw := Load{ Fullness: -1, gateway: &p.available[i], } recommended = append(recommended, recommendedGw) found = true } } if !found { log.Println("ERROR: invalid host in recommended list of hostnames", host) return } } } p.recommended = recommended } /* get at most max gateways. the method of picking depends on whether we're * doing manual override, and if we got useful info from menshen */ func (p *gatewayPool) getBest(transport string, tz, max int) ([]Gateway, error) { if p.isManualLocation() { if len(p.recommended) != 0 { return p.getGatewaysFromMenshenByLocation(p.userChoice, transport) } else { return p.getRandomGatewaysByLocation(p.userChoice, transport) } } else if len(p.recommended) != 0 { return p.getGatewaysFromMenshen(transport, max) } else { return p.getGatewaysByTimezone(transport, tz, max) } } func (p *gatewayPool) getAll(transport string, tz int) ([]Gateway, error) { if len(p.recommended) != 0 { return p.getGatewaysFromMenshen(transport, 999) } return p.getGatewaysByTimezone(transport, tz, 999) } /* picks at most max gateways, filtering by transport, from the ordered list menshen returned */ func (p *gatewayPool) getGatewaysFromMenshen(transport string, max int) ([]Gateway, error) { gws := make([]Gateway, 0) for _, gw := range p.recommended { if !gw.gateway.isTransport(transport) { continue } gws = append(gws, *gw.gateway) if len(gws) == max { break } } return gws, nil } /* the old timezone based heuristic, when everything goes wrong */ func (p *gatewayPool) getGatewaysByTimezone(transport string, tzOffsetHours, max int) ([]Gateway, error) { gws := make([]Gateway, 0) gwVector := []gatewayDistance{} for _, gw := range p.available { if !gw.isTransport(transport) { continue } distance := 13 gwOffset, err := strconv.Atoi(p.locations[gw.Location].Timezone) if err != nil { log.Printf("Error sorting gateways: %v", err) return gws, err } distance = tzDistance(tzOffsetHours, gwOffset) gwVector = append(gwVector, gatewayDistance{gw, distance}) } rand.Seed(time.Now().UnixNano()) cmp := func(i, j int) bool { if gwVector[i].distance == gwVector[j].distance { return rand.Intn(2) == 1 } return gwVector[i].distance < gwVector[j].distance } sort.Slice(gwVector, cmp) for _, gw := range gwVector { gws = append(gws, gw.gateway) if len(gws) == max { break } } return gws, nil } func newGatewayPool(eip *eipService) *gatewayPool { p := gatewayPool{} p.available = eip.getGateways() p.locations = eip.Locations p.byLocation = make(map[string][]*Gateway) p.populateLocationList() return &p } func tzDistance(offset1, offset2 int) int { abs := func(x int) int { if x < 0 { return -x } return x } distance := abs(offset1 - offset2) if distance > 12 { distance = 24 - distance } return distance } func stringInSlice(a string, list []string) bool { for _, b := range list { if b == a { return true } } return false }