diff options
Diffstat (limited to 'vendor/git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel/redialpacketconn.go')
-rw-r--r-- | vendor/git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel/redialpacketconn.go | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/vendor/git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel/redialpacketconn.go b/vendor/git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel/redialpacketconn.go new file mode 100644 index 0000000..cf6a8c9 --- /dev/null +++ b/vendor/git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel/redialpacketconn.go @@ -0,0 +1,204 @@ +package turbotunnel + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" +) + +// RedialPacketConn implements a long-lived net.PacketConn atop a sequence of +// other, transient net.PacketConns. RedialPacketConn creates a new +// net.PacketConn by calling a provided dialContext function. Whenever the +// net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn +// calls the dialContext function again and starts sending and receiving packets +// on the new net.PacketConn. RedialPacketConn's own ReadFrom and WriteTo +// methods return an error only when the dialContext function returns an error. +// +// RedialPacketConn uses static local and remote addresses that are independent +// of those of any dialed net.PacketConn. +type RedialPacketConn struct { + localAddr net.Addr + remoteAddr net.Addr + dialContext func(context.Context) (net.PacketConn, error) + recvQueue chan []byte + sendQueue chan []byte + closed chan struct{} + closeOnce sync.Once + // The first dial error, which causes the clientPacketConn to be + // closed and is returned from future read/write operations. Compare to + // the rerr and werr in io.Pipe. + err atomic.Value +} + +// NewQueuePacketConn makes a new RedialPacketConn, with the given static local +// and remote addresses, and dialContext function. +func NewRedialPacketConn( + localAddr, remoteAddr net.Addr, + dialContext func(context.Context) (net.PacketConn, error), +) *RedialPacketConn { + c := &RedialPacketConn{ + localAddr: localAddr, + remoteAddr: remoteAddr, + dialContext: dialContext, + recvQueue: make(chan []byte, queueSize), + sendQueue: make(chan []byte, queueSize), + closed: make(chan struct{}), + err: atomic.Value{}, + } + go c.dialLoop() + return c +} + +// dialLoop repeatedly calls c.dialContext and passes the resulting +// net.PacketConn to c.exchange. It returns only when c is closed or dialContext +// returns an error. +func (c *RedialPacketConn) dialLoop() { + ctx, cancel := context.WithCancel(context.Background()) + for { + select { + case <-c.closed: + cancel() + return + default: + } + conn, err := c.dialContext(ctx) + if err != nil { + c.closeWithError(err) + cancel() + return + } + c.exchange(conn) + conn.Close() + } +} + +// exchange calls ReadFrom on the given net.PacketConn and places the resulting +// packets in the receive queue, and takes packets from the send queue and calls +// WriteTo on them, making the current net.PacketConn active. +func (c *RedialPacketConn) exchange(conn net.PacketConn) { + readErrCh := make(chan error) + writeErrCh := make(chan error) + + go func() { + defer close(readErrCh) + for { + select { + case <-c.closed: + return + case <-writeErrCh: + return + default: + } + + var buf [1500]byte + n, _, err := conn.ReadFrom(buf[:]) + if err != nil { + readErrCh <- err + return + } + p := make([]byte, n) + copy(p, buf[:]) + select { + case c.recvQueue <- p: + default: // OK to drop packets. + } + } + }() + + go func() { + defer close(writeErrCh) + for { + select { + case <-c.closed: + return + case <-readErrCh: + return + case p := <-c.sendQueue: + _, err := conn.WriteTo(p, c.remoteAddr) + if err != nil { + writeErrCh <- err + return + } + } + } + }() + + select { + case <-readErrCh: + case <-writeErrCh: + } +} + +// ReadFrom reads a packet from the currently active net.PacketConn. The +// packet's original remote address is replaced with the RedialPacketConn's own +// remote address. +func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + select { + case <-c.closed: + return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)} + default: + } + select { + case <-c.closed: + return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)} + case buf := <-c.recvQueue: + return copy(p, buf), c.remoteAddr, nil + } +} + +// WriteTo writes a packet to the currently active net.PacketConn. The addr +// argument is ignored and instead replaced with the RedialPacketConn's own +// remote address. +func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + // addr is ignored. + select { + case <-c.closed: + return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)} + default: + } + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.sendQueue <- buf: + return len(buf), nil + default: + // Drop the outgoing packet if the send queue is full. + return len(buf), nil + } +} + +// closeWithError unblocks pending operations and makes future operations fail +// with the given error. If err is nil, it becomes errClosedPacketConn. +func (c *RedialPacketConn) closeWithError(err error) error { + var once bool + c.closeOnce.Do(func() { + // Store the error to be returned by future read/write + // operations. + if err == nil { + err = errors.New("operation on closed connection") + } + c.err.Store(err) + close(c.closed) + once = true + }) + if !once { + return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} + } + return nil +} + +// Close unblocks pending operations and makes future operations fail with a +// "closed connection" error. +func (c *RedialPacketConn) Close() error { + return c.closeWithError(nil) +} + +// LocalAddr returns the localAddr value that was passed to NewRedialPacketConn. +func (c *RedialPacketConn) LocalAddr() net.Addr { return c.localAddr } + +func (c *RedialPacketConn) SetDeadline(t time.Time) error { return errNotImplemented } +func (c *RedialPacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented } +func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented } |