diff options
Diffstat (limited to 'vendor/github.com/pion/transport/vnet/net.go')
-rw-r--r-- | vendor/github.com/pion/transport/vnet/net.go | 677 |
1 files changed, 677 insertions, 0 deletions
diff --git a/vendor/github.com/pion/transport/vnet/net.go b/vendor/github.com/pion/transport/vnet/net.go new file mode 100644 index 0000000..4dc6a2a --- /dev/null +++ b/vendor/github.com/pion/transport/vnet/net.go @@ -0,0 +1,677 @@ +package vnet + +import ( + "encoding/binary" + "errors" + "fmt" + "math/rand" + "net" + "strconv" + "strings" + "sync" +) + +const ( + lo0String = "lo0String" + udpString = "udp" +) + +var ( + macAddrCounter uint64 = 0xBEEFED910200 //nolint:gochecknoglobals + errNoInterface = errors.New("no interface is available") + errNotFound = errors.New("not found") + errUnexpectedNetwork = errors.New("unexpected network") + errCantAssignRequestedAddr = errors.New("can't assign requested address") + errUnknownNetwork = errors.New("unknown network") + errNoRouterLinked = errors.New("no router linked") + errInvalidPortNumber = errors.New("invalid port number") + errUnexpectedTypeSwitchFailure = errors.New("unexpected type-switch failure") + errBindFailerFor = errors.New("bind failed for") + errEndPortLessThanStart = errors.New("end port is less than the start") + errPortSpaceExhausted = errors.New("port space exhausted") + errVNetDisabled = errors.New("vnet is not enabled") +) + +func newMACAddress() net.HardwareAddr { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, macAddrCounter) + macAddrCounter++ + return b[2:] +} + +type vNet struct { + interfaces []*Interface // read-only + staticIPs []net.IP // read-only + router *Router // read-only + udpConns *udpConnMap // read-only + mutex sync.RWMutex +} + +func (v *vNet) _getInterfaces() ([]*Interface, error) { + if len(v.interfaces) == 0 { + return nil, errNoInterface + } + + return v.interfaces, nil +} + +func (v *vNet) getInterfaces() ([]*Interface, error) { + v.mutex.RLock() + defer v.mutex.RUnlock() + + return v._getInterfaces() +} + +// caller must hold the mutex (read) +func (v *vNet) _getInterface(ifName string) (*Interface, error) { + ifs, err := v._getInterfaces() + if err != nil { + return nil, err + } + for _, ifc := range ifs { + if ifc.Name == ifName { + return ifc, nil + } + } + + return nil, fmt.Errorf("interface %s %w", ifName, errNotFound) +} + +func (v *vNet) getInterface(ifName string) (*Interface, error) { + v.mutex.RLock() + defer v.mutex.RUnlock() + + return v._getInterface(ifName) +} + +// caller must hold the mutex +func (v *vNet) getAllIPAddrs(ipv6 bool) []net.IP { + ips := []net.IP{} + + for _, ifc := range v.interfaces { + addrs, err := ifc.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + if ipNet, ok := addr.(*net.IPNet); ok { + ip = ipNet.IP + } else if ipAddr, ok := addr.(*net.IPAddr); ok { + ip = ipAddr.IP + } else { + continue + } + + if !ipv6 { + if ip.To4() != nil { + ips = append(ips, ip) + } + } + } + } + + return ips +} + +func (v *vNet) setRouter(r *Router) error { + v.mutex.Lock() + defer v.mutex.Unlock() + + v.router = r + return nil +} + +func (v *vNet) onInboundChunk(c Chunk) { + v.mutex.Lock() + defer v.mutex.Unlock() + + if c.Network() == udpString { + if conn, ok := v.udpConns.find(c.DestinationAddr()); ok { + conn.onInboundChunk(c) + } + } +} + +// caller must hold the mutex +func (v *vNet) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) { + // validate network + if network != udpString && network != "udp4" { + return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network) + } + + if locAddr == nil { + locAddr = &net.UDPAddr{ + IP: net.IPv4zero, + } + } else if locAddr.IP == nil { + locAddr.IP = net.IPv4zero + } + + // validate address. do we have that address? + if !v.hasIPAddr(locAddr.IP) { + return nil, &net.OpError{ + Op: "listen", + Net: network, + Addr: locAddr, + Err: fmt.Errorf("bind: %w", errCantAssignRequestedAddr), + } + } + + if locAddr.Port == 0 { + // choose randomly from the range between 5000 and 5999 + port, err := v.assignPort(locAddr.IP, 5000, 5999) + if err != nil { + return nil, &net.OpError{ + Op: "listen", + Net: network, + Addr: locAddr, + Err: err, + } + } + locAddr.Port = port + } else if _, ok := v.udpConns.find(locAddr); ok { + return nil, &net.OpError{ + Op: "listen", + Net: network, + Addr: locAddr, + Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse), + } + } + + conn, err := newUDPConn(locAddr, remAddr, v) + if err != nil { + return nil, err + } + + err = v.udpConns.insert(conn) + if err != nil { + return nil, err + } + + return conn, nil +} + +func (v *vNet) listenPacket(network string, address string) (UDPPacketConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + locAddr, err := v.resolveUDPAddr(network, address) + if err != nil { + return nil, err + } + + return v._dialUDP(network, locAddr, nil) +} + +func (v *vNet) listenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + return v._dialUDP(network, locAddr, nil) +} + +func (v *vNet) dialUDP(network string, locAddr, remAddr *net.UDPAddr) (UDPPacketConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + return v._dialUDP(network, locAddr, remAddr) +} + +func (v *vNet) dial(network string, address string) (UDPPacketConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + remAddr, err := v.resolveUDPAddr(network, address) + if err != nil { + return nil, err + } + + // Determine source address + srcIP := v.determineSourceIP(nil, remAddr.IP) + + locAddr := &net.UDPAddr{IP: srcIP, Port: 0} + + return v._dialUDP(network, locAddr, remAddr) +} + +func (v *vNet) resolveUDPAddr(network, address string) (*net.UDPAddr, error) { + if network != udpString && network != "udp4" { + return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) + } + + host, sPort, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + // Check if host is a domain name + ip := net.ParseIP(host) + if ip == nil { + host = strings.ToLower(host) + if host == "localhost" { + ip = net.IPv4(127, 0, 0, 1) + } else { + // host is a domain name. resolve IP address by the name + if v.router == nil { + return nil, errNoRouterLinked + } + + ip, err = v.router.resolver.lookUp(host) + if err != nil { + return nil, err + } + } + } + + port, err := strconv.Atoi(sPort) + if err != nil { + return nil, errInvalidPortNumber + } + + udpAddr := &net.UDPAddr{ + IP: ip, + Port: port, + } + + return udpAddr, nil +} + +func (v *vNet) write(c Chunk) error { + if c.Network() == udpString { + if udp, ok := c.(*chunkUDP); ok { + if c.getDestinationIP().IsLoopback() { + if conn, ok := v.udpConns.find(udp.DestinationAddr()); ok { + conn.onInboundChunk(udp) + } + return nil + } + } else { + return errUnexpectedTypeSwitchFailure + } + } + + if v.router == nil { + return errNoRouterLinked + } + + v.router.push(c) + return nil +} + +func (v *vNet) onClosed(addr net.Addr) { + if addr.Network() == udpString { + //nolint:errcheck + v.udpConns.delete(addr) // #nosec + } +} + +// This method determines the srcIP based on the dstIP when locIP +// is any IP address ("0.0.0.0" or "::"). If locIP is a non-any addr, +// this method simply returns locIP. +// caller must hold the mutex +func (v *vNet) determineSourceIP(locIP, dstIP net.IP) net.IP { + if locIP != nil && !locIP.IsUnspecified() { + return locIP + } + + var srcIP net.IP + + if dstIP.IsLoopback() { + srcIP = net.ParseIP("127.0.0.1") + } else { + ifc, err2 := v._getInterface("eth0") + if err2 != nil { + return nil + } + + addrs, err2 := ifc.Addrs() + if err2 != nil { + return nil + } + + if len(addrs) == 0 { + return nil + } + + var findIPv4 bool + if locIP != nil { + findIPv4 = (locIP.To4() != nil) + } else { + findIPv4 = (dstIP.To4() != nil) + } + + for _, addr := range addrs { + ip := addr.(*net.IPNet).IP + if findIPv4 { + if ip.To4() != nil { + srcIP = ip + break + } + } else { + if ip.To4() == nil { + srcIP = ip + break + } + } + } + } + + return srcIP +} + +// caller must hold the mutex +func (v *vNet) hasIPAddr(ip net.IP) bool { //nolint:gocognit + for _, ifc := range v.interfaces { + if addrs, err := ifc.Addrs(); err == nil { + for _, addr := range addrs { + var locIP net.IP + if ipNet, ok := addr.(*net.IPNet); ok { + locIP = ipNet.IP + } else if ipAddr, ok := addr.(*net.IPAddr); ok { + locIP = ipAddr.IP + } else { + continue + } + + switch ip.String() { + case "0.0.0.0": + if locIP.To4() != nil { + return true + } + case "::": + if locIP.To4() == nil { + return true + } + default: + if locIP.Equal(ip) { + return true + } + } + } + } + } + + return false +} + +// caller must hold the mutex +func (v *vNet) allocateLocalAddr(ip net.IP, port int) error { + // gather local IP addresses to bind + var ips []net.IP + if ip.IsUnspecified() { + ips = v.getAllIPAddrs(ip.To4() == nil) + } else if v.hasIPAddr(ip) { + ips = []net.IP{ip} + } + + if len(ips) == 0 { + return fmt.Errorf("%w %s", errBindFailerFor, ip.String()) + } + + // check if all these transport addresses are not in use + for _, ip2 := range ips { + addr := &net.UDPAddr{ + IP: ip2, + Port: port, + } + if _, ok := v.udpConns.find(addr); ok { + return &net.OpError{ + Op: "bind", + Net: udpString, + Addr: addr, + Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse), + } + } + } + + return nil +} + +// caller must hold the mutex +func (v *vNet) assignPort(ip net.IP, start, end int) (int, error) { + // choose randomly from the range between start and end (inclusive) + if end < start { + return -1, errEndPortLessThanStart + } + + space := end + 1 - start + offset := rand.Intn(space) //nolint:gosec + for i := 0; i < space; i++ { + port := ((offset + i) % space) + start + + err := v.allocateLocalAddr(ip, port) + if err == nil { + return port, nil + } + } + + return -1, errPortSpaceExhausted +} + +// NetConfig is a bag of configuration parameters passed to NewNet(). +type NetConfig struct { + // StaticIPs is an array of static IP addresses to be assigned for this Net. + // If no static IP address is given, the router will automatically assign + // an IP address. + StaticIPs []string + + // StaticIP is deprecated. Use StaticIPs. + StaticIP string +} + +// Net represents a local network stack euivalent to a set of layers from NIC +// up to the transport (UDP / TCP) layer. +type Net struct { + v *vNet + ifs []*Interface +} + +// NewNet creates an instance of Net. +// If config is nil, the virtual network is disabled. (uses corresponding +// net.Xxxx() operations. +// By design, it always have lo0 and eth0 interfaces. +// The lo0 has the address 127.0.0.1 assigned by default. +// IP address for eth0 will be assigned when this Net is added to a router. +func NewNet(config *NetConfig) *Net { + if config == nil { + ifs := []*Interface{} + if orgIfs, err := net.Interfaces(); err == nil { + for _, orgIfc := range orgIfs { + ifc := NewInterface(orgIfc) + if addrs, err := orgIfc.Addrs(); err == nil { + for _, addr := range addrs { + ifc.AddAddr(addr) + } + } + + ifs = append(ifs, ifc) + } + } + + return &Net{ifs: ifs} + } + + lo0 := NewInterface(net.Interface{ + Index: 1, + MTU: 16384, + Name: lo0String, + HardwareAddr: nil, + Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast, + }) + lo0.AddAddr(&net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(8, 32), + }) + + eth0 := NewInterface(net.Interface{ + Index: 2, + MTU: 1500, + Name: "eth0", + HardwareAddr: newMACAddress(), + Flags: net.FlagUp | net.FlagMulticast, + }) + + var staticIPs []net.IP + for _, ipStr := range config.StaticIPs { + if ip := net.ParseIP(ipStr); ip != nil { + staticIPs = append(staticIPs, ip) + } + } + if len(config.StaticIP) > 0 { + if ip := net.ParseIP(config.StaticIP); ip != nil { + staticIPs = append(staticIPs, ip) + } + } + + v := &vNet{ + interfaces: []*Interface{lo0, eth0}, + staticIPs: staticIPs, + udpConns: newUDPConnMap(), + } + + return &Net{ + v: v, + } +} + +// Interfaces returns a list of the system's network interfaces. +func (n *Net) Interfaces() ([]*Interface, error) { + if n.v == nil { + return n.ifs, nil + } + + return n.v.getInterfaces() +} + +// InterfaceByName returns the interface specified by name. +func (n *Net) InterfaceByName(name string) (*Interface, error) { + if n.v == nil { + for _, ifc := range n.ifs { + if ifc.Name == name { + return ifc, nil + } + } + + return nil, fmt.Errorf("interface %s %w", name, errNotFound) + } + + return n.v.getInterface(name) +} + +// ListenPacket announces on the local network address. +func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) { + if n.v == nil { + return net.ListenPacket(network, address) + } + + return n.v.listenPacket(network, address) +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (UDPPacketConn, error) { + if n.v == nil { + return net.ListenUDP(network, locAddr) + } + + return n.v.listenUDP(network, locAddr) +} + +// Dial connects to the address on the named network. +func (n *Net) Dial(network, address string) (net.Conn, error) { + if n.v == nil { + return net.Dial(network, address) + } + + return n.v.dial(network, address) +} + +// CreateDialer creates an instance of vnet.Dialer +func (n *Net) CreateDialer(dialer *net.Dialer) Dialer { + if n.v == nil { + return &vDialer{ + dialer: dialer, + } + } + + return &vDialer{ + dialer: dialer, + v: n.v, + } +} + +// DialUDP acts like Dial for UDP networks. +func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (UDPPacketConn, error) { + if n.v == nil { + return net.DialUDP(network, laddr, raddr) + } + + return n.v.dialUDP(network, laddr, raddr) +} + +// ResolveUDPAddr returns an address of UDP end point. +func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + if n.v == nil { + return net.ResolveUDPAddr(network, address) + } + + return n.v.resolveUDPAddr(network, address) +} + +func (n *Net) getInterface(ifName string) (*Interface, error) { + if n.v == nil { + return nil, errVNetDisabled + } + + return n.v.getInterface(ifName) +} + +func (n *Net) setRouter(r *Router) error { + if n.v == nil { + return errVNetDisabled + } + + return n.v.setRouter(r) +} + +func (n *Net) onInboundChunk(c Chunk) { + if n.v == nil { + return + } + + n.v.onInboundChunk(c) +} + +func (n *Net) getStaticIPs() []net.IP { + if n.v == nil { + return nil + } + + return n.v.staticIPs +} + +// IsVirtual tests if the virtual network is enabled. +func (n *Net) IsVirtual() bool { + return n.v != nil +} + +// Dialer is identical to net.Dialer excepts that its methods +// (Dial, DialContext) are overridden to use virtual network. +// Use vnet.CreateDialer() to create an instance of this Dialer. +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +type vDialer struct { + dialer *net.Dialer + v *vNet +} + +func (d *vDialer) Dial(network, address string) (net.Conn, error) { + if d.v == nil { + return d.dialer.Dial(network, address) + } + + return d.v.dial(network, address) +} |