diff options
-rw-r--r-- | pkg/backend/api.go | 19 | ||||
-rw-r--r-- | pkg/bitmask/bitmask.go | 1 | ||||
-rw-r--r-- | pkg/vpn/bonafide/bonafide.go | 2 | ||||
-rw-r--r-- | pkg/vpn/bonafide/gateways.go | 12 | ||||
-rw-r--r-- | pkg/vpn/main.go | 5 | ||||
-rw-r--r-- | pkg/vpn/openvpn.go | 77 |
6 files changed, 72 insertions, 44 deletions
diff --git a/pkg/backend/api.go b/pkg/backend/api.go index d43751b..60d51f3 100644 --- a/pkg/backend/api.go +++ b/pkg/backend/api.go @@ -7,7 +7,7 @@ import ( "encoding/json" "log" "strconv" - "time" + "strings" "unsafe" "0xacab.org/leap/bitmask-vpn/pkg/bitmask" @@ -62,8 +62,8 @@ func UseLocation(label string) { ctx.bm.UseGateway(label) go trigger(OnStatusChanged) - if label != ctx.CurrentLocation { - reconnect() + if ctx.Status == on && label != strings.ToLower(ctx.CurrentLocation) { + ctx.bm.Reconnect() } } @@ -74,16 +74,9 @@ func UseAutomaticGateway() { ctx.bm.UseAutomaticGateway() go trigger(OnStatusChanged) - reconnect() -} - -// TODO implement Reconnect - do not tear whole fw down in between - -func reconnect() { - time.Sleep(200 * time.Millisecond) - SwitchOff() - time.Sleep(500 * time.Millisecond) - SwitchOn() + if ctx.Status == on { + ctx.bm.Reconnect() + } } func UseTransport(label string) { diff --git a/pkg/bitmask/bitmask.go b/pkg/bitmask/bitmask.go index d4c1c3f..3f484e8 100644 --- a/pkg/bitmask/bitmask.go +++ b/pkg/bitmask/bitmask.go @@ -22,6 +22,7 @@ type Bitmask interface { StartVPN(provider string) error CanStartVPN() bool StopVPN() error + Reconnect() error ReloadFirewall() error GetStatus() (string, error) InstallHelpers() error diff --git a/pkg/vpn/bonafide/bonafide.go b/pkg/vpn/bonafide/bonafide.go index 710ece5..fc1bc95 100644 --- a/pkg/vpn/bonafide/bonafide.go +++ b/pkg/vpn/bonafide/bonafide.go @@ -225,7 +225,7 @@ func (b *Bonafide) GetGateways(transport string) ([]Gateway, error) { } // GetAllGateways only filters gateways by transport. -// TODO could pass "any" instead? +// if "any" is provided it will return all gateways for all transports func (b *Bonafide) GetAllGateways(transport string) ([]Gateway, error) { err := b.maybeInitializeEIP() if err != nil { diff --git a/pkg/vpn/bonafide/gateways.go b/pkg/vpn/bonafide/gateways.go index 460d80f..50359e9 100644 --- a/pkg/vpn/bonafide/gateways.go +++ b/pkg/vpn/bonafide/gateways.go @@ -59,6 +59,10 @@ type gatewayPool struct { locations map[string]Location } +func (gw Gateway) isTransport(transport string) bool { + return transport == "any" || gw.Transport == transport +} + func (p *gatewayPool) populateLocationList() { for i, gw := range p.available { p.byLocation[gw.Location] = append(p.byLocation[gw.Location], &p.available[i]) @@ -118,7 +122,7 @@ func (p *gatewayPool) getRandomGatewaysByLocation(location, transport string) ([ var gateways []Gateway for _, gw := range gws { - if gw.Transport == transport { + if gw.isTransport(transport) { gateways = append(gateways, *gw) } if len(gateways) == maxGateways { @@ -143,7 +147,7 @@ func (p *gatewayPool) getGatewaysFromMenshenByLocation(location, transport strin var gateways []Gateway for _, gw := range p.recommended { - if gw.gateway.Transport != transport { + if !gw.gateway.isTransport(transport) { continue } for _, locatedGw := range gws { @@ -275,7 +279,7 @@ func (p *gatewayPool) getAll(transport string, tz int) ([]Gateway, error) { func (p *gatewayPool) getGatewaysFromMenshen(transport string, max int) ([]Gateway, error) { gws := make([]Gateway, 0) for _, gw := range p.recommended { - if gw.gateway.Transport != transport { + if !gw.gateway.isTransport(transport) { continue } gws = append(gws, *gw.gateway) @@ -292,7 +296,7 @@ func (p *gatewayPool) getGatewaysByTimezone(transport string, tzOffsetHours, max gwVector := []gatewayDistance{} for _, gw := range p.available { - if gw.Transport != transport { + if !gw.isTransport(transport) { continue } distance := 13 diff --git a/pkg/vpn/main.go b/pkg/vpn/main.go index 29b843b..5f5117a 100644 --- a/pkg/vpn/main.go +++ b/pkg/vpn/main.go @@ -36,6 +36,8 @@ type Bitmask struct { launch *launcher transport string shapes *shapeshifter.ShapeShifter + certPemPath string + openvpnArgs []string } // Init the connection to bitmask @@ -50,8 +52,9 @@ func Init() (*Bitmask, error) { if err != nil { return nil, err } - b := Bitmask{tempdir, bonafide.Gateway{}, statusCh, nil, bf, launch, "", nil} + b := Bitmask{tempdir, bonafide.Gateway{}, statusCh, nil, bf, launch, "", nil, "", []string{}} + b.launch.firewallStop() /* TODO -- we still want to do this, since it resets the fw/vpn if running from a previous one, but first we need to complete all the diff --git a/pkg/vpn/openvpn.go b/pkg/vpn/openvpn.go index 2304dbd..2e552a1 100644 --- a/pkg/vpn/openvpn.go +++ b/pkg/vpn/openvpn.go @@ -35,20 +35,21 @@ const ( // StartVPN for provider func (b *Bitmask) StartVPN(provider string) error { - var proxy string - if b.transport != "" { - var err error - proxy, err = b.startTransport() - if err != nil { - return err - } - } - if !b.CanStartVPN() { return errors.New("BUG: cannot start vpn") } - err := b.startOpenVPN(proxy) - return err + + var err error + b.certPemPath, err = b.getCert() + if err != nil { + return err + } + b.openvpnArgs, err = b.bonafide.GetOpenvpnArgs() + if err != nil { + return err + } + + return b.startOpenVPN() } func (b *Bitmask) CanStartVPN() bool { @@ -110,17 +111,9 @@ func (b *Bitmask) listenShapeErr() { } } -func (b *Bitmask) startOpenVPN(proxy string) error { - certPemPath, err := b.getCert() - if err != nil { - return err - } - arg, err := b.bonafide.GetOpenvpnArgs() - if err != nil { - return err - } - - if proxy == "" { +func (b *Bitmask) startOpenVPN() error { + arg := b.openvpnArgs + if b.transport == "" { gateways, err := b.bonafide.GetGateways("openvpn") if err != nil { return err @@ -136,6 +129,11 @@ func (b *Bitmask) startOpenVPN(proxy string) error { } } } else { + proxy, err := b.startTransport() + if err != nil { + return err + } + gateways, err := b.bonafide.GetGateways(b.transport) if err != nil { return err @@ -153,8 +151,8 @@ func (b *Bitmask) startOpenVPN(proxy string) error { "--management-client", "--management", openvpnManagementAddr, openvpnManagementPort, "--ca", b.getCaCertPath(), - "--cert", certPemPath, - "--key", certPemPath) + "--cert", b.certPemPath, + "--key", b.certPemPath) return b.launch.openvpnStart(arg...) } @@ -186,6 +184,35 @@ func (b *Bitmask) StopVPN() error { return b.launch.openvpnStop() } +// Reconnect to the VPN +func (b *Bitmask) Reconnect() error { + if !b.CanStartVPN() { + return errors.New("BUG: cannot start vpn") + } + + status, err := b.GetStatus() + if err != nil { + return err + } + log.Println("reconnect") + if status != Off { + if b.shapes != nil { + b.shapes.Close() + b.shapes = nil + } + err = b.launch.openvpnStop() + if err != nil { + return err + } + } + + err = b.launch.firewallStop() + if err != nil { + return err + } + return b.startOpenVPN() +} + // ReloadFirewall restarts the firewall func (b *Bitmask) ReloadFirewall() error { err := b.launch.firewallStop() @@ -199,7 +226,7 @@ func (b *Bitmask) ReloadFirewall() error { } if status != Off { - gateways, err := b.bonafide.GetGateways("openvpn") + gateways, err := b.bonafide.GetAllGateways("any") if err != nil { return err } |