diff options
Diffstat (limited to 'quicwrapper/listener.go')
-rw-r--r-- | quicwrapper/listener.go | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/quicwrapper/listener.go b/quicwrapper/listener.go new file mode 100644 index 0000000..de28d2a --- /dev/null +++ b/quicwrapper/listener.go @@ -0,0 +1,206 @@ +package quicwrapper + +import ( + "context" + "crypto/tls" + "net" + "sync" + "sync/atomic" + + "github.com/apex/log" + "github.com/getlantern/ops" + quic "github.com/lucas-clemente/quic-go" +) + +// ListenAddr creates a QUIC server listening on a given address. +// The net.Conn instances returned by the net.Listener may be multiplexed connections. +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (net.Listener, error) { + tlsConf = defaultNextProtos(tlsConf, DefaultServerProtos) + ql, err := quic.ListenAddr(addr, tlsConf, config) + if err != nil { + return nil, err + } + return listen(ql, tlsConf, config) +} + +// Listen creates a QUIC server listening on a given net.PacketConn +// The net.Conn instances returned by the net.Listener may be multiplexed connections. +// The caller is responsible for closing the net.PacketConn after the listener has been +// closed. +func Listen(pconn net.PacketConn, tlsConf *tls.Config, config *Config) (net.Listener, error) { + tlsConf = defaultNextProtos(tlsConf, DefaultServerProtos) + ql, err := quic.Listen(pconn, tlsConf, config) + if err != nil { + return nil, err + } + return listen(ql, tlsConf, config) +} + +func listen(ql quic.Listener, tlsConf *tls.Config, config *Config) (net.Listener, error) { + l := &listener{ + quicListener: ql, + config: config, + connections: make(chan net.Conn, 1000), + acceptError: make(chan error, 1), + closedSignal: make(chan struct{}), + } + // XXX wat is this? + ops.Go(l.listen) + //ops.Go(l.logStats) + + return l, nil +} + +var _ net.Listener = &listener{} + +// wraps quic.Listener to create a net.Listener +type listener struct { + numConnections int64 + numVirtualConnections int64 + quicListener quic.Listener + config *Config + connections chan net.Conn + acceptError chan error + closedSignal chan struct{} + closeErr error + closeOnce sync.Once +} + +// implements net.Listener.Accept +func (l *listener) Accept() (net.Conn, error) { + select { + case conn, ok := <-l.connections: + if !ok { + return nil, ErrListenerClosed + } + return conn, nil + case err, ok := <-l.acceptError: + if !ok { + return nil, ErrListenerClosed + } + return nil, err + case <-l.closedSignal: + return nil, ErrListenerClosed + } +} + +// implements net.Listener.Close +// Shut down the QUIC listener. +// this implicitly sends CONNECTION_CLOSE frames to peers +// note: it is still the responsibility of the caller +// to call Close() on any Conn returned from Accept() +func (l *listener) Close() error { + l.closeOnce.Do(func() { + close(l.closedSignal) + l.closeErr = l.quicListener.Close() + }) + return l.closeErr +} + +func (l *listener) isClosed() bool { + select { + case <-l.closedSignal: + return true + default: + return false + } +} + +// implements net.Listener.Addr +func (l *listener) Addr() net.Addr { + return l.quicListener.Addr() +} + +func (l *listener) listen() { + group := &sync.WaitGroup{} + + defer func() { + l.Close() + close(l.acceptError) + // wait for writers to exit, drain connections + group.Wait() + close(l.connections) + for c := range l.connections { + c.Close() + } + + log.Debugf("Listener finished with Connections: %d Virtual: %d", atomic.LoadInt64(&l.numConnections), atomic.LoadInt64(&l.numVirtualConnections)) + }() + + for { + session, err := l.quicListener.Accept(context.Background()) + if err != nil { + if !l.isClosed() { + l.acceptError <- err + } + return + } + if l.isClosed() { + session.CloseWithError(0, "") + return + } else { + atomic.AddInt64(&l.numConnections, 1) + group.Add(1) + ops.Go(func() { + l.handleSession(session) + atomic.AddInt64(&l.numConnections, -1) + group.Done() + }) + } + } +} + +func (l *listener) handleSession(session quic.Connection) { + + // keep a smoothed average of the bandwidth estimate + // for the session + // bw := NewEMABandwidthSampler(session) + // bw.Start() + + // track active session connections + active := make(map[quic.StreamID]Conn) + var mx sync.Mutex + + defer func() { + // bw.Stop() + session.CloseWithError(0, "") + + // snapshot any non-closed connections, then nil out active list + // conns being closed will 'remove' themselves from the nil + // list, not the snapshot. + var snapshot map[quic.StreamID]Conn + mx.Lock() + snapshot = active + active = nil + mx.Unlock() + + // immediately close any connections that are still active + for _, conn := range snapshot { + conn.Close() + } + }() + + for { + stream, err := session.AcceptStream(context.Background()) + if err != nil { + if isPeerGoingAway(err) { + // log.Tracef("Accepting stream: Peer going away (%v)", err) + return + } else { + // log.Errorf("Accepting stream: %v", err) + return + } + } else { + atomic.AddInt64(&l.numVirtualConnections, 1) + conn := newConn(stream, session, func(id quic.StreamID) { + atomic.AddInt64(&l.numVirtualConnections, -1) + // remove conn from active list + mx.Lock() + delete(active, id) + mx.Unlock() + }) + + l.connections <- conn + } + } +} |