summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/transport/vnet/router.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/transport/vnet/router.go')
-rw-r--r--vendor/github.com/pion/transport/vnet/router.go605
1 files changed, 605 insertions, 0 deletions
diff --git a/vendor/github.com/pion/transport/vnet/router.go b/vendor/github.com/pion/transport/vnet/router.go
new file mode 100644
index 0000000..616d2c9
--- /dev/null
+++ b/vendor/github.com/pion/transport/vnet/router.go
@@ -0,0 +1,605 @@
+package vnet
+
+import (
+ "errors"
+ "fmt"
+ "math/rand"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pion/logging"
+)
+
+const (
+ defaultRouterQueueSize = 0 // unlimited
+)
+
+var (
+ errInvalidLocalIPinStaticIPs = errors.New("invalid local IP in StaticIPs")
+ errLocalIPBeyondStaticIPsSubset = errors.New("mapped in StaticIPs is beyond subnet")
+ errLocalIPNoStaticsIPsAssociated = errors.New("all StaticIPs must have associated local IPs")
+ errRouterAlreadyStarted = errors.New("router already started")
+ errRouterAlreadyStopped = errors.New("router already stopped")
+ errStaticIPisBeyondSubnet = errors.New("static IP is beyond subnet")
+ errAddressSpaceExhausted = errors.New("address space exhausted")
+ errNoIPAddrEth0 = errors.New("no IP address is assigned for eth0")
+)
+
+// Generate a unique router name
+var assignRouterName = func() func() string { //nolint:gochecknoglobals
+ var routerIDCtr uint64
+
+ return func() string {
+ n := atomic.AddUint64(&routerIDCtr, 1)
+ return fmt.Sprintf("router%d", n)
+ }
+}()
+
+// RouterConfig ...
+type RouterConfig struct {
+ // Name of router. If not specified, a unique name will be assigned.
+ Name string
+ // CIDR notation, like "192.0.2.0/24"
+ CIDR string
+ // StaticIPs is an array of static IP addresses to be assigned for this router.
+ // If no static IP address is given, the router will automatically assign
+ // an IP address.
+ // This will be ignored if this router is the root.
+ StaticIPs []string
+ // StaticIP is deprecated. Use StaticIPs.
+ StaticIP string
+ // Internal queue size
+ QueueSize int
+ // Effective only when this router has a parent router
+ NATType *NATType
+ // Minimum Delay
+ MinDelay time.Duration
+ // Max Jitter
+ MaxJitter time.Duration
+ // Logger factory
+ LoggerFactory logging.LoggerFactory
+}
+
+// NIC is a nework inerface controller that interfaces Router
+type NIC interface {
+ getInterface(ifName string) (*Interface, error)
+ onInboundChunk(c Chunk)
+ getStaticIPs() []net.IP
+ setRouter(r *Router) error
+}
+
+// ChunkFilter is a handler users can add to filter chunks.
+// If the filter returns false, the packet will be dropped.
+type ChunkFilter func(c Chunk) bool
+
+// Router ...
+type Router struct {
+ name string // read-only
+ interfaces []*Interface // read-only
+ ipv4Net *net.IPNet // read-only
+ staticIPs []net.IP // read-only
+ staticLocalIPs map[string]net.IP // read-only,
+ lastID byte // requires mutex [x], used to assign the last digit of IPv4 address
+ queue *chunkQueue // read-only
+ parent *Router // read-only
+ children []*Router // read-only
+ natType *NATType // read-only
+ nat *networkAddressTranslator // read-only
+ nics map[string]NIC // read-only
+ stopFunc func() // requires mutex [x]
+ resolver *resolver // read-only
+ chunkFilters []ChunkFilter // requires mutex [x]
+ minDelay time.Duration // requires mutex [x]
+ maxJitter time.Duration // requires mutex [x]
+ mutex sync.RWMutex // thread-safe
+ pushCh chan struct{} // writer requires mutex
+ loggerFactory logging.LoggerFactory // read-only
+ log logging.LeveledLogger // read-only
+}
+
+// NewRouter ...
+func NewRouter(config *RouterConfig) (*Router, error) {
+ loggerFactory := config.LoggerFactory
+ log := loggerFactory.NewLogger("vnet")
+
+ _, ipv4Net, err := net.ParseCIDR(config.CIDR)
+ if err != nil {
+ return nil, err
+ }
+
+ queueSize := defaultRouterQueueSize
+ if config.QueueSize > 0 {
+ queueSize = config.QueueSize
+ }
+
+ // set up network interface, lo0
+ lo0 := NewInterface(net.Interface{
+ Index: 1,
+ MTU: 16384,
+ Name: lo0String,
+ HardwareAddr: nil,
+ Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast,
+ })
+ lo0.AddAddr(&net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""})
+
+ // set up network interface, eth0
+ eth0 := NewInterface(net.Interface{
+ Index: 2,
+ MTU: 1500,
+ Name: "eth0",
+ HardwareAddr: newMACAddress(),
+ Flags: net.FlagUp | net.FlagMulticast,
+ })
+
+ // local host name resolver
+ resolver := newResolver(&resolverConfig{
+ LoggerFactory: config.LoggerFactory,
+ })
+
+ name := config.Name
+ if len(name) == 0 {
+ name = assignRouterName()
+ }
+
+ var staticIPs []net.IP
+ staticLocalIPs := map[string]net.IP{}
+ for _, ipStr := range config.StaticIPs {
+ ipPair := strings.Split(ipStr, "/")
+ if ip := net.ParseIP(ipPair[0]); ip != nil {
+ if len(ipPair) > 1 {
+ locIP := net.ParseIP(ipPair[1])
+ if locIP == nil {
+ return nil, errInvalidLocalIPinStaticIPs
+ }
+ if !ipv4Net.Contains(locIP) {
+ return nil, fmt.Errorf("local IP %s %w", locIP.String(), errLocalIPBeyondStaticIPsSubset)
+ }
+ staticLocalIPs[ip.String()] = locIP
+ }
+ staticIPs = append(staticIPs, ip)
+ }
+ }
+ if len(config.StaticIP) > 0 {
+ log.Warn("StaticIP is deprecated. Use StaticIPs instead")
+ if ip := net.ParseIP(config.StaticIP); ip != nil {
+ staticIPs = append(staticIPs, ip)
+ }
+ }
+
+ if nStaticLocal := len(staticLocalIPs); nStaticLocal > 0 {
+ if nStaticLocal != len(staticIPs) {
+ return nil, errLocalIPNoStaticsIPsAssociated
+ }
+ }
+
+ return &Router{
+ name: name,
+ interfaces: []*Interface{lo0, eth0},
+ ipv4Net: ipv4Net,
+ staticIPs: staticIPs,
+ staticLocalIPs: staticLocalIPs,
+ queue: newChunkQueue(queueSize),
+ natType: config.NATType,
+ nics: map[string]NIC{},
+ resolver: resolver,
+ minDelay: config.MinDelay,
+ maxJitter: config.MaxJitter,
+ pushCh: make(chan struct{}, 1),
+ loggerFactory: loggerFactory,
+ log: log,
+ }, nil
+}
+
+// caller must hold the mutex
+func (r *Router) getInterfaces() ([]*Interface, error) {
+ if len(r.interfaces) == 0 {
+ return nil, fmt.Errorf("%w is available", errNoInterface)
+ }
+
+ return r.interfaces, nil
+}
+
+func (r *Router) getInterface(ifName string) (*Interface, error) {
+ r.mutex.RLock()
+ defer r.mutex.RUnlock()
+
+ ifs, err := r.getInterfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifc := range ifs {
+ if ifc.Name == ifName {
+ return ifc, nil
+ }
+ }
+
+ return nil, fmt.Errorf("interface %s %w", ifName, errNotFound)
+}
+
+// Start ...
+func (r *Router) Start() error {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ if r.stopFunc != nil {
+ return errRouterAlreadyStarted
+ }
+
+ cancelCh := make(chan struct{})
+
+ go func() {
+ loop:
+ for {
+ d, err := r.processChunks()
+ if err != nil {
+ r.log.Errorf("[%s] %s", r.name, err.Error())
+ break
+ }
+
+ if d <= 0 {
+ select {
+ case <-r.pushCh:
+ case <-cancelCh:
+ break loop
+ }
+ } else {
+ t := time.NewTimer(d)
+ select {
+ case <-t.C:
+ case <-cancelCh:
+ break loop
+ }
+ }
+ }
+ }()
+
+ r.stopFunc = func() {
+ close(cancelCh)
+ }
+
+ for _, child := range r.children {
+ if err := child.Start(); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Stop ...
+func (r *Router) Stop() error {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ if r.stopFunc == nil {
+ return errRouterAlreadyStopped
+ }
+
+ for _, router := range r.children {
+ r.mutex.Unlock()
+ err := router.Stop()
+ r.mutex.Lock()
+
+ if err != nil {
+ return err
+ }
+ }
+
+ r.stopFunc()
+ r.stopFunc = nil
+ return nil
+}
+
+// caller must hold the mutex
+func (r *Router) addNIC(nic NIC) error {
+ ifc, err := nic.getInterface("eth0")
+ if err != nil {
+ return err
+ }
+
+ var ips []net.IP
+
+ if ips = nic.getStaticIPs(); len(ips) == 0 {
+ // assign an IP address
+ ip, err2 := r.assignIPAddress()
+ if err2 != nil {
+ return err2
+ }
+ ips = append(ips, ip)
+ }
+
+ for _, ip := range ips {
+ if !r.ipv4Net.Contains(ip) {
+ return fmt.Errorf("%w: %s", errStaticIPisBeyondSubnet, r.ipv4Net.String())
+ }
+
+ ifc.AddAddr(&net.IPNet{
+ IP: ip,
+ Mask: r.ipv4Net.Mask,
+ })
+
+ r.nics[ip.String()] = nic
+ }
+
+ if err = nic.setRouter(r); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// AddRouter adds a chile Router.
+func (r *Router) AddRouter(router *Router) error {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ // Router is a NIC. Add it as a NIC so that packets are routed to this child
+ // router.
+ err := r.addNIC(router)
+ if err != nil {
+ return err
+ }
+
+ if err = router.setRouter(r); err != nil {
+ return err
+ }
+
+ r.children = append(r.children, router)
+ return nil
+}
+
+// AddNet ...
+func (r *Router) AddNet(nic NIC) error {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ return r.addNIC(nic)
+}
+
+// AddHost adds a mapping of hostname and an IP address to the local resolver.
+func (r *Router) AddHost(hostName string, ipAddr string) error {
+ return r.resolver.addHost(hostName, ipAddr)
+}
+
+// AddChunkFilter adds a filter for chunks traversing this router.
+// You may add more than one filter. The filters are called in the order of this method call.
+// If a chunk is dropped by a filter, subsequent filter will not receive the chunk.
+func (r *Router) AddChunkFilter(filter ChunkFilter) {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ r.chunkFilters = append(r.chunkFilters, filter)
+}
+
+// caller should hold the mutex
+func (r *Router) assignIPAddress() (net.IP, error) {
+ // See: https://stackoverflow.com/questions/14915188/ip-address-ending-with-zero
+
+ if r.lastID == 0xfe {
+ return nil, errAddressSpaceExhausted
+ }
+
+ ip := make(net.IP, 4)
+ copy(ip, r.ipv4Net.IP[:3])
+ r.lastID++
+ ip[3] = r.lastID
+ return ip, nil
+}
+
+func (r *Router) push(c Chunk) {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ r.log.Debugf("[%s] route %s", r.name, c.String())
+ if r.stopFunc != nil {
+ c.setTimestamp()
+ if r.queue.push(c) {
+ select {
+ case r.pushCh <- struct{}{}:
+ default:
+ }
+ } else {
+ r.log.Warnf("[%s] queue was full. dropped a chunk", r.name)
+ }
+ }
+}
+
+func (r *Router) processChunks() (time.Duration, error) {
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+
+ // Introduce jitter by delaying the processing of chunks.
+ if r.maxJitter > 0 {
+ jitter := time.Duration(rand.Int63n(int64(r.maxJitter))) //nolint:gosec
+ time.Sleep(jitter)
+ }
+
+ // cutOff
+ // v min delay
+ // |<--->|
+ // +------------:--
+ // |OOOOOOXXXXX : --> time
+ // +------------:--
+ // |<--->| now
+ // due
+
+ enteredAt := time.Now()
+ cutOff := enteredAt.Add(-r.minDelay)
+
+ var d time.Duration // the next sleep duration
+
+ for {
+ d = 0
+
+ c := r.queue.peek()
+ if c == nil {
+ break // no more chunk in the queue
+ }
+
+ // check timestamp to find if the chunk is due
+ if c.getTimestamp().After(cutOff) {
+ // There is one or more chunk in the queue but none of them are due.
+ // Calculate the next sleep duration here.
+ nextExpire := c.getTimestamp().Add(r.minDelay)
+ d = nextExpire.Sub(enteredAt)
+ break
+ }
+
+ var ok bool
+ if c, ok = r.queue.pop(); !ok {
+ break // no more chunk in the queue
+ }
+
+ blocked := false
+ for i := 0; i < len(r.chunkFilters); i++ {
+ filter := r.chunkFilters[i]
+ if !filter(c) {
+ blocked = true
+ break
+ }
+ }
+ if blocked {
+ continue // discard
+ }
+
+ dstIP := c.getDestinationIP()
+
+ // check if the desination is in our subnet
+ if r.ipv4Net.Contains(dstIP) {
+ // search for the destination NIC
+ var nic NIC
+ if nic, ok = r.nics[dstIP.String()]; !ok {
+ // NIC not found. drop it.
+ r.log.Debugf("[%s] %s unreachable", r.name, c.String())
+ continue
+ }
+
+ // found the NIC, forward the chunk to the NIC.
+ // call to NIC must unlock mutex
+ r.mutex.Unlock()
+ nic.onInboundChunk(c)
+ r.mutex.Lock()
+ continue
+ }
+
+ // the destination is outside of this subnet
+ // is this WAN?
+ if r.parent == nil {
+ // this WAN. No route for this chunk
+ r.log.Debugf("[%s] no route found for %s", r.name, c.String())
+ continue
+ }
+
+ // Pass it to the parent via NAT
+ toParent, err := r.nat.translateOutbound(c)
+ if err != nil {
+ return 0, err
+ }
+
+ if toParent == nil {
+ continue
+ }
+
+ //nolint:godox
+ /* FIXME: this implementation would introduce a duplicate packet!
+ if r.nat.natType.Hairpining {
+ hairpinned, err := r.nat.translateInbound(toParent)
+ if err != nil {
+ r.log.Warnf("[%s] %s", r.name, err.Error())
+ } else {
+ go func() {
+ r.push(hairpinned)
+ }()
+ }
+ }
+ */
+
+ // call to parent router mutex unlock mutex
+ r.mutex.Unlock()
+ r.parent.push(toParent)
+ r.mutex.Lock()
+ }
+
+ return d, nil
+}
+
+// caller must hold the mutex
+func (r *Router) setRouter(parent *Router) error {
+ r.parent = parent
+ r.resolver.setParent(parent.resolver)
+
+ // when this method is called, one or more IP address has already been assigned by
+ // the parent router.
+ ifc, err := r.getInterface("eth0")
+ if err != nil {
+ return err
+ }
+
+ if len(ifc.addrs) == 0 {
+ return errNoIPAddrEth0
+ }
+
+ mappedIPs := []net.IP{}
+ localIPs := []net.IP{}
+
+ for _, ifcAddr := range ifc.addrs {
+ var ip net.IP
+ switch addr := ifcAddr.(type) {
+ case *net.IPNet:
+ ip = addr.IP
+ case *net.IPAddr: // Do we really need this case?
+ ip = addr.IP
+ default:
+ }
+
+ if ip == nil {
+ continue
+ }
+
+ mappedIPs = append(mappedIPs, ip)
+
+ if locIP := r.staticLocalIPs[ip.String()]; locIP != nil {
+ localIPs = append(localIPs, locIP)
+ }
+ }
+
+ // Set up NAT here
+ if r.natType == nil {
+ r.natType = &NATType{
+ MappingBehavior: EndpointIndependent,
+ FilteringBehavior: EndpointAddrPortDependent,
+ Hairpining: false,
+ PortPreservation: false,
+ MappingLifeTime: 30 * time.Second,
+ }
+ }
+ r.nat, err = newNAT(&natConfig{
+ name: r.name,
+ natType: *r.natType,
+ mappedIPs: mappedIPs,
+ localIPs: localIPs,
+ loggerFactory: r.loggerFactory,
+ })
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (r *Router) onInboundChunk(c Chunk) {
+ fromParent, err := r.nat.translateInbound(c)
+ if err != nil {
+ r.log.Warnf("[%s] %s", r.name, err.Error())
+ return
+ }
+
+ r.push(fromParent)
+}
+
+func (r *Router) getStaticIPs() []net.IP {
+ return r.staticIPs
+}