summaryrefslogtreecommitdiff
path: root/vendor/git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel/redialpacketconn.go
blob: cf6a8c9e47953ef93ddb137bee073643216cba85 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
package turbotunnel

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// RedialPacketConn implements a long-lived net.PacketConn atop a sequence of
// other, transient net.PacketConns. RedialPacketConn creates a new
// net.PacketConn by calling a provided dialContext function. Whenever the
// net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn
// calls the dialContext function again and starts sending and receiving packets
// on the new net.PacketConn. RedialPacketConn's own ReadFrom and WriteTo
// methods return an error only once the RedialPacketConn has been closed,
// which happens when Close is called or when the dialContext function returns
// an error.
//
// RedialPacketConn uses static local and remote addresses that are independent
// of those of any dialed net.PacketConn.
//
// Fields:
//   - localAddr and remoteAddr are the static addresses given to
//     NewRedialPacketConn; remoteAddr is also the destination of every
//     outgoing packet.
//   - dialContext produces each successive underlying net.PacketConn.
//   - recvQueue holds packets read from the active net.PacketConn until
//     ReadFrom collects them.
//   - sendQueue holds packets passed to WriteTo until they are written to the
//     active net.PacketConn.
//   - closed is closed exactly once, guarded by closeOnce, to signal that the
//     RedialPacketConn is closed.
type RedialPacketConn struct {
	localAddr   net.Addr
	remoteAddr  net.Addr
	dialContext func(context.Context) (net.PacketConn, error)
	recvQueue   chan []byte
	sendQueue   chan []byte
	closed      chan struct{}
	closeOnce   sync.Once
	// The error that closed the RedialPacketConn: the first dial error, or
	// a generic "operation on closed connection" error when the close was
	// requested without an error. It is stored before the closed channel
	// is closed and is returned, wrapped in a net.OpError, from later
	// operations. Compare to the rerr and werr in io.Pipe.
	err atomic.Value
}

// NewRedialPacketConn returns a RedialPacketConn that reports the given static
// local and remote addresses and obtains its underlying net.PacketConns from
// dialContext. Before returning, it starts the background goroutine that
// dials, and redials after errors, until the RedialPacketConn is closed or
// dialContext fails.
func NewRedialPacketConn(
	localAddr, remoteAddr net.Addr,
	dialContext func(context.Context) (net.PacketConn, error),
) *RedialPacketConn {
	// The err field is deliberately left at its zero value; it is only
	// populated when the connection is closed.
	rc := new(RedialPacketConn)
	rc.localAddr = localAddr
	rc.remoteAddr = remoteAddr
	rc.dialContext = dialContext
	rc.recvQueue = make(chan []byte, queueSize)
	rc.sendQueue = make(chan []byte, queueSize)
	rc.closed = make(chan struct{})

	go rc.dialLoop()
	return rc
}

// dialLoop repeatedly calls c.dialContext and passes the resulting
// net.PacketConn to c.exchange, closing each net.PacketConn once exchange is
// done with it. It returns only when c is closed or dialContext returns an
// error; in the latter case c is closed with that error.
//
// The context handed to dialContext is canceled as soon as c is closed, so
// that a Close issued while a dial is still in progress interrupts the dial
// instead of leaving this goroutine blocked until the dial finishes on its
// own.
func (c *RedialPacketConn) dialLoop() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Propagate closure of c into ctx. The second case lets this goroutine
	// exit when dialLoop returns, via the deferred cancel above.
	go func() {
		select {
		case <-c.closed:
			cancel()
		case <-ctx.Done():
		}
	}()

	for {
		select {
		case <-c.closed:
			return
		default:
		}
		conn, err := c.dialContext(ctx)
		if err != nil {
			// If c was already closed (for example, the dial was
			// canceled by Close), this is a no-op and the original
			// close error is preserved.
			c.closeWithError(err)
			return
		}
		c.exchange(conn)
		conn.Close()
	}
}

// exchange makes conn the currently active net.PacketConn. One goroutine calls
// ReadFrom on conn and places the resulting packets in the receive queue
// (dropping them when the queue is full); another takes packets from the send
// queue and calls WriteTo on conn with c.remoteAddr as the destination.
//
// exchange returns as soon as either goroutine stops, whether because of a
// ReadFrom or WriteTo error or because c was closed. It does not close conn;
// the caller does that, which (per the net.PacketConn contract) also unblocks
// a ReadFrom that is still pending so the reader goroutine can finish.
func (c *RedialPacketConn) exchange(conn net.PacketConn) {
	// Each goroutine sends at most one error, and each channel has room
	// for it. That way a goroutine can always report its error and exit,
	// even when nobody is left to receive it: for example, when the reader
	// fails only after exchange has returned and the caller has closed
	// conn, or when both directions fail at the same time. With unbuffered
	// channels such a send would block forever and leak the goroutine.
	readErrCh := make(chan error, 1)
	writeErrCh := make(chan error, 1)

	go func() {
		defer close(readErrCh)
		// Scratch buffer reused across iterations; every packet is
		// copied out of it before being queued.
		var buf [1500]byte
		for {
			select {
			case <-c.closed:
				return
			case <-writeErrCh:
				// The writer has stopped (error or close).
				return
			default:
			}

			n, _, err := conn.ReadFrom(buf[:])
			if err != nil {
				readErrCh <- err
				return
			}
			p := make([]byte, n)
			copy(p, buf[:n])
			select {
			case c.recvQueue <- p:
			default: // OK to drop packets.
			}
		}
	}()

	go func() {
		defer close(writeErrCh)
		for {
			select {
			case <-c.closed:
				return
			case <-readErrCh:
				// The reader has stopped (error or close).
				return
			case p := <-c.sendQueue:
				_, err := conn.WriteTo(p, c.remoteAddr)
				if err != nil {
					writeErrCh <- err
					return
				}
			}
		}
	}()

	// A receive succeeds both when an error value is delivered and when
	// the channel is closed by the goroutine's deferred close.
	select {
	case <-readErrCh:
	case <-writeErrCh:
	}
}

// ReadFrom blocks until a packet is available in the receive queue or the
// RedialPacketConn is closed. On success it copies the packet into p
// (truncating it if p is too short) and reports the RedialPacketConn's own
// static remote address rather than the packet's original source. Once the
// RedialPacketConn is closed, it returns a *net.OpError wrapping the error
// that caused the close.
func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	// Built lazily: c.err is only guaranteed to hold a value after closed
	// has been closed.
	closedErr := func() error {
		return &net.OpError{
			Op:     "read",
			Net:    c.LocalAddr().Network(),
			Source: c.LocalAddr(),
			Addr:   c.remoteAddr,
			Err:    c.err.Load().(error),
		}
	}

	// Give closure priority over any packets still sitting in the queue.
	select {
	case <-c.closed:
		return 0, nil, closedErr()
	default:
	}

	select {
	case pkt := <-c.recvQueue:
		n := copy(p, pkt)
		return n, c.remoteAddr, nil
	case <-c.closed:
		return 0, nil, closedErr()
	}
}

// WriteTo queues a copy of p for transmission on the currently active
// net.PacketConn. The addr argument is ignored; queued packets are sent to the
// RedialPacketConn's own static remote address. WriteTo never blocks: when the
// send queue is full the packet is silently dropped, and len(p) is reported as
// written either way. Once the RedialPacketConn is closed, it returns a
// *net.OpError wrapping the error that caused the close.
func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	select {
	case <-c.closed:
		closedErr := &net.OpError{
			Op:     "write",
			Net:    c.LocalAddr().Network(),
			Source: c.LocalAddr(),
			Addr:   c.remoteAddr,
			Err:    c.err.Load().(error),
		}
		return 0, closedErr
	default:
	}

	// Queue a private copy so the caller is free to reuse p afterwards.
	pkt := make([]byte, len(p))
	n := copy(pkt, p)
	select {
	case c.sendQueue <- pkt:
	default:
		// The send queue is full: drop the outgoing packet.
	}
	return n, nil
}

// closeWithError closes the RedialPacketConn, unblocking pending operations
// and making future operations fail with err. A nil err is replaced by a
// generic "operation on closed connection" error. Only the first call has any
// effect and returns nil; every later call leaves the stored error untouched
// and returns a *net.OpError wrapping it.
func (c *RedialPacketConn) closeWithError(err error) error {
	alreadyClosed := true
	c.closeOnce.Do(func() {
		if err == nil {
			err = errors.New("operation on closed connection")
		}
		// Publish the error before closing the channel, so that anyone
		// who observes closed can safely load it.
		c.err.Store(err)
		close(c.closed)
		alreadyClosed = false
	})
	if alreadyClosed {
		return &net.OpError{
			Op:   "close",
			Net:  c.LocalAddr().Network(),
			Addr: c.LocalAddr(),
			Err:  c.err.Load().(error),
		}
	}
	return nil
}

// Close closes the RedialPacketConn, unblocking pending operations and making
// future operations fail with a "closed connection" error. It returns nil if
// this call performed the close; if the RedialPacketConn was already closed
// (by an earlier Close or by a dial error) it returns a *net.OpError wrapping
// the error from that first close.
func (c *RedialPacketConn) Close() error {
	// A nil error asks closeWithError to substitute its generic
	// closed-connection error.
	var noCause error
	return c.closeWithError(noCause)
}

// LocalAddr reports the static local address that was supplied to
// NewRedialPacketConn.
func (c *RedialPacketConn) LocalAddr() net.Addr {
	return c.localAddr
}

// SetDeadline is not supported; it ignores t and always returns
// errNotImplemented.
func (c *RedialPacketConn) SetDeadline(t time.Time) error {
	return errNotImplemented
}

// SetReadDeadline is not supported; it ignores t and always returns
// errNotImplemented.
func (c *RedialPacketConn) SetReadDeadline(t time.Time) error {
	return errNotImplemented
}

// SetWriteDeadline is not supported; it ignores t and always returns
// errNotImplemented.
func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error {
	return errNotImplemented
}