summaryrefslogtreecommitdiff
path: root/quicwrapper/dialer.go
blob: 554bb14cf7415da48fdf2dce1aa8930eb53f5574 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
package quicwrapper

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"net"
	"sync"

	"github.com/apex/log"
	"github.com/getlantern/netx"
	quic "github.com/lucas-clemente/quic-go"
)

// QuicDialFn is a function that may be used to establish a new QUIC session
// (a quic.Connection) to addr using the given TLS and QUIC configuration.
// The context governs cancellation of the dial.
type QuicDialFn func(ctx context.Context, addr string, tlsConf *tls.Config, config *quic.Config) (quic.Connection, error)

// UDPDialFn prepares the UDP transport for a QUIC dial to addr. It returns the
// local net.PacketConn that QUIC traffic is sent from, together with the
// resolved remote UDP address to dial.
type UDPDialFn func(addr string) (net.PacketConn, *net.UDPAddr, error)

var (
	// DialWithNetx dials QUIC over a UDP socket that is resolved and opened
	// through the netx package (see DialUDPNetx). Closing the resulting
	// session with CloseWithError also closes that socket.
	DialWithNetx QuicDialFn = newDialerWithUDPDialer(DialUDPNetx)

	// DialWithoutNetx dials QUIC with quic-go's own address resolution and
	// socket handling (quic.DialAddrContext), bypassing netx.
	DialWithoutNetx QuicDialFn = quic.DialAddrContext

	// defaultQuicDial is the dialer NewClient and NewClientWithPinnedCert use
	// when the caller passes a nil QuicDialFn.
	defaultQuicDial QuicDialFn = DialWithNetx
)

// wrappedSession pairs a QUIC connection with the net.PacketConn it was dialed
// over, so that closing the session also releases the packet connection.
type wrappedSession struct {
	quic.Connection
	conn net.PacketConn
}

// CloseWithError closes the QUIC connection with the given application error
// code and message, and then closes the underlying packet connection. The
// packet connection is closed even if closing the QUIC connection fails. The
// QUIC close error takes precedence; otherwise the packet connection's close
// error (or nil) is returned.
func (w wrappedSession) CloseWithError(code quic.ApplicationErrorCode, mesg string) error {
	if sessionErr := w.Connection.CloseWithError(code, mesg); sessionErr != nil {
		_ = w.conn.Close() // secondary error is dropped in favour of sessionErr
		return sessionErr
	}
	return w.conn.Close()
}

// newDialerWithUDPDialer builds a QuicDialFn on top of dial. For each call,
// dial supplies the local net.PacketConn and the resolved remote address, and
// quic.DialContext establishes the QUIC connection over them.
//
// On success the connection is wrapped in a wrappedSession so that closing
// it also closes the packet connection. If quic.DialContext fails, the packet
// connection is closed before the error is returned.
func newDialerWithUDPDialer(dial UDPDialFn) QuicDialFn {
	quicDial := func(ctx context.Context, addr string, tlsConf *tls.Config, config *quic.Config) (quic.Connection, error) {
		pconn, raddr, err := dial(addr)
		if err != nil {
			return nil, err
		}
		qconn, err := quic.DialContext(ctx, pconn, raddr, addr, tlsConf, config)
		if err != nil {
			pconn.Close() // best effort; the dial error is what the caller needs
			return nil, err
		}
		return wrappedSession{Connection: qconn, conn: pconn}, nil
	}
	return quicDial
}

// DialUDPNetx is a UDPDialFn backed by the netx package. It resolves addr with
// netx.ResolveUDPAddr, then opens a local UDP socket with netx.ListenUDP. The
// socket is bound to the IPv4 wildcard address (0.0.0.0) on a port chosen by
// the OS.
func DialUDPNetx(addr string) (net.PacketConn, *net.UDPAddr, error) {
	remote, err := netx.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, nil, err
	}

	local := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
	pconn, err := netx.ListenUDP("udp", local)
	if err != nil {
		return nil, nil, err
	}
	return pconn, remote, nil
}

// NewClient returns a Client that multiplexes QUIC connections to addr over a
// single shared session, configured by tlsConf and config.
//
// The session is not established here. It is created with dial on the first
// call to Dial(), DialContext() or Connect(). Passing a nil dial selects the
// package's default QUIC dialer.
//
// No certificate pinning is applied; use NewClientWithPinnedCert for that.
func NewClient(addr string, tlsConf *tls.Config, config *Config, dial QuicDialFn) *Client {
	var noPinnedCert *x509.Certificate // nil disables certificate pinning
	return NewClientWithPinnedCert(addr, tlsConf, config, dial, noPinnedCert)
}

// NewClientWithPinnedCert returns a new client configured as with NewClient,
// but accepting only a specific given certificate. If the certificate
// presented by the connected server does not match the given certificate, the
// connection is rejected. This check is performed on every newly dialed
// session regardless of tls.Config settings (i.e. even if InsecureSkipVerify
// is true).
//
// If a nil certificate is given, the check is not performed. Any certificate
// that is valid according to the given tls.Config is then accepted
// (equivalent to NewClient behavior).
//
// If dial is nil, the default QUIC dialer is used. tlsConf is passed through
// defaultNextProtos with DefaultClientProtos before being stored. No session
// is established until the first Dial, DialContext or Connect.
func NewClientWithPinnedCert(addr string, tlsConf *tls.Config, config *Config, dial QuicDialFn, cert *x509.Certificate) *Client {
	if dial == nil {
		dial = defaultQuicDial
	}

	// session is left as its zero value (nil): it is dialed lazily.
	return &Client{
		address:    addr,
		tlsConf:    defaultNextProtos(tlsConf, DefaultClientProtos),
		config:     config,
		dial:       dial,
		pinnedCert: cert,
	}
}

// Client multiplexes QUIC connections (streams) to a single server address
// over one shared session. The session is dialed lazily by
// getOrCreateSession. It is dropped by clearSession, after which the next
// dial establishes a fresh one.
//
// Fields:
//   - session: the current shared session; nil until first use or after it is
//     cleared. Guarded by muSession.
//   - address, tlsConf, config: passed to dial for every new session.
//   - pinnedCert: when non-nil, the only leaf certificate the server may
//     present; checked by verifyPinnedCert after each new session is dialed.
//   - dial: the QuicDialFn used to create sessions.
type Client struct {
	session    quic.Connection
	muSession  sync.Mutex
	address    string
	tlsConf    *tls.Config
	pinnedCert *x509.Certificate
	config     *Config
	dial       QuicDialFn
}

// DialContext creates a new multiplexed QUIC connection (a stream on the
// client's shared session) to the server configured in the client. The
// session is dialed first if necessary. The given Context governs
// cancellation / timeout. If initial handshaking is performed, the operation
// is additionally governed by the HandshakeTimeout value given in the client
// Config.
//
// If opening the stream fails because the session can no longer be used,
// that session is discarded so that a later call dials a new one.
func (c *Client) DialContext(ctx context.Context) (*Conn, error) {
	session, err := c.getOrCreateSession(ctx)
	if err != nil {
		return nil, fmt.Errorf("connecting session: %w", err)
	}
	stream, err := session.OpenStreamSync(ctx)
	if err != nil {
		c.discardSessionAfterStreamError(session, err)
		return nil, fmt.Errorf("establishing stream: %w", err)
	}
	return newConn(stream, session, nil), nil
}

// discardSessionAfterStreamError clears and closes session after
// OpenStreamSync failed with err. It does so only when the failure means the
// session cannot be reused AND session is still the client's current session.
//
// The session is considered unusable when either of these holds:
//   - Its Context is done. quic-go cancels it once the connection is closed
//     for any reason, e.g. peer close or stateless reset. Such errors are not
//     always non-temporary net.Errors.
//   - err is a non-temporary net.Error. This was the original criterion. It is
//     kept in case OpenStreamSync reports the failure before the Context is
//     cancelled.
//
// A caller's own ctx cancellation (context.Canceled, or DeadlineExceeded,
// whose Temporary() is true) matches neither condition on a live session.
// Such cancellation therefore never tears down the shared session.
func (c *Client) discardSessionAfterStreamError(session quic.Connection, err error) {
	unusable := session.Context().Err() != nil
	//lint:ignore SA1019 Temporary is kept to preserve the original error classification
	if ne, ok := err.(net.Error); ok && !ne.Temporary() {
		unusable = true
	}
	if !unusable {
		return
	}

	// Only discard the session that actually failed. Another goroutine may
	// already have cleared it and dialed a healthy replacement, which must not
	// be closed here.
	// NOTE(review): == relies on session values being comparable (true for
	// pointer-backed quic-go connections and for wrappedSession over them).
	c.muSession.Lock()
	current := c.session == session
	if current {
		c.session = nil
	}
	c.muSession.Unlock()

	if current {
		// start over again on the next dial after an unrecoverable error.
		log.Debugf("Closing quic session (%v)", err.Error())
		session.CloseWithError(0, "")
	}
}

// Dial is DialContext with context.Background(). It opens a new multiplexed
// QUIC connection to the client's server without any caller-imposed deadline
// or cancellation.
func (c *Client) Dial() (*Conn, error) {
	ctx := context.Background()
	return c.DialContext(ctx)
}

// Connect establishes the client's shared session right away, instead of
// waiting for the first Dial or DialContext (which would otherwise do it
// lazily).
//
// Use it to pre-establish the multiplexed session. Note that doing so also
// starts idle timeout tracking, keepalives etc. Any error hit while
// handshaking is returned.
//
// It is safe to call concurrently with Dial. Whichever caller returns, the
// handshake has completed by the time it does.
func (c *Client) Connect(ctx context.Context) error {
	if _, err := c.getOrCreateSession(ctx); err != nil {
		return err
	}
	return nil
}

// getOrCreateSession returns the client's shared session. It dials a new one
// (and verifies the pinned certificate, if one is configured) when there is
// no session, or when the cached session has already been closed.
//
// muSession is held for the whole call. Concurrent callers are therefore
// serialized, and at most one dial is in flight at a time.
func (c *Client) getOrCreateSession(ctx context.Context) (quic.Connection, error) {
	c.muSession.Lock()
	defer c.muSession.Unlock()

	if c.session != nil {
		if c.session.Context().Err() == nil {
			return c.session, nil
		}
		// The cached session has been closed and can never open another
		// stream. Close it explicitly so that resources tied to it (e.g. the
		// UDP socket owned by a wrappedSession) are released. Then fall
		// through and dial a replacement.
		log.Debugf("Closing quic session (%v)", "session already closed")
		c.session.CloseWithError(0, "")
		c.session = nil
	}

	session, err := c.dial(ctx, c.address, c.tlsConf, c.config)
	if err != nil {
		return nil, err
	}
	if c.pinnedCert != nil {
		if err = c.verifyPinnedCert(session); err != nil {
			session.CloseWithError(0, "")
			return nil, err
		}
	}
	c.session = session
	return session, nil
}

// verifyPinnedCert checks that the leaf certificate presented by the peer of
// session is identical to c.pinnedCert. It returns an error in two cases:
// when the peer presented no certificates, or when its leaf certificate
// differs. The mismatch error includes both certificates PEM-encoded to aid
// diagnosis.
func (c *Client) verifyPinnedCert(session quic.Connection) error {
	certs := session.ConnectionState().TLS.PeerCertificates
	if len(certs) == 0 {
		return fmt.Errorf("Server did not present any certificates!")
	}

	serverCert := certs[0]
	if serverCert.Equal(c.pinnedCert) {
		return nil
	}

	received := pem.EncodeToMemory(&pem.Block{
		Type:    "CERTIFICATE",
		Headers: nil,
		Bytes:   serverCert.Raw,
	})

	expected := pem.EncodeToMemory(&pem.Block{
		Type:    "CERTIFICATE",
		Headers: nil,
		Bytes:   c.pinnedCert.Raw,
	})

	// %s rather than %v: EncodeToMemory returns []byte, which %v would render
	// as a list of decimal byte values instead of readable PEM text.
	return fmt.Errorf("Server's certificate didn't match expected! Server had\n%s\nbut expected:\n%s", received, expected)
}

// Close tears down the session currently held by the client, if any, together
// with every connection multiplexed over it. The client stays usable: a later
// Dial, DialContext or Connect establishes a new session. Close always
// returns nil.
func (c *Client) Close() error {
	const reason = "client closed"
	c.clearSession(reason)
	return nil
}

// clearSession detaches the current session from the client while holding
// muSession. Then, outside the lock, it logs reason at debug level and closes
// the detached session. It does nothing beyond taking the lock when there is
// no session.
func (c *Client) clearSession(reason string) {
	c.muSession.Lock()
	detached := c.session
	c.session = nil
	c.muSession.Unlock()

	if detached == nil {
		return
	}
	log.Debugf("Closing quic session (%v)", reason)
	detached.CloseWithError(0, "")
}