summaryrefslogtreecommitdiff
path: root/vendor/github.com/xtaci/smux/session.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/xtaci/smux/session.go')
-rw-r--r--vendor/github.com/xtaci/smux/session.go525
1 files changed, 525 insertions, 0 deletions
diff --git a/vendor/github.com/xtaci/smux/session.go b/vendor/github.com/xtaci/smux/session.go
new file mode 100644
index 0000000..bc56066
--- /dev/null
+++ b/vendor/github.com/xtaci/smux/session.go
@@ -0,0 +1,525 @@
+package smux
+
+import (
+ "container/heap"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const (
+ defaultAcceptBacklog = 1024
+)
+
+var (
+ ErrInvalidProtocol = errors.New("invalid protocol")
+ ErrConsumed = errors.New("peer consumed more than sent")
+ ErrGoAway = errors.New("stream id overflows, should start a new connection")
+ ErrTimeout = errors.New("timeout")
+ ErrWouldBlock = errors.New("operation would block on IO")
+)
+
+type writeRequest struct {
+ prio uint64
+ frame Frame
+ result chan writeResult
+}
+
+type writeResult struct {
+ n int
+ err error
+}
+
+type buffersWriter interface {
+ WriteBuffers(v [][]byte) (n int, err error)
+}
+
+// Session defines a multiplexed connection for streams
+type Session struct {
+ conn io.ReadWriteCloser
+
+ config *Config
+ nextStreamID uint32 // next stream identifier
+ nextStreamIDLock sync.Mutex
+
+ bucket int32 // token bucket
+ bucketNotify chan struct{} // used for waiting for tokens
+
+ streams map[uint32]*Stream // all streams in this session
+ streamLock sync.Mutex // locks streams
+
+ die chan struct{} // flag session has died
+ dieOnce sync.Once
+
+ // socket error handling
+ socketReadError atomic.Value
+ socketWriteError atomic.Value
+ chSocketReadError chan struct{}
+ chSocketWriteError chan struct{}
+ socketReadErrorOnce sync.Once
+ socketWriteErrorOnce sync.Once
+
+ // smux protocol errors
+ protoError atomic.Value
+ chProtoError chan struct{}
+ protoErrorOnce sync.Once
+
+ chAccepts chan *Stream
+
+ dataReady int32 // flag data has arrived
+
+ goAway int32 // flag id exhausted
+
+ deadline atomic.Value
+
+ shaper chan writeRequest // a shaper for writing
+ writes chan writeRequest
+}
+
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
+ s := new(Session)
+ s.die = make(chan struct{})
+ s.conn = conn
+ s.config = config
+ s.streams = make(map[uint32]*Stream)
+ s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
+ s.bucket = int32(config.MaxReceiveBuffer)
+ s.bucketNotify = make(chan struct{}, 1)
+ s.shaper = make(chan writeRequest)
+ s.writes = make(chan writeRequest)
+ s.chSocketReadError = make(chan struct{})
+ s.chSocketWriteError = make(chan struct{})
+ s.chProtoError = make(chan struct{})
+
+ if client {
+ s.nextStreamID = 1
+ } else {
+ s.nextStreamID = 0
+ }
+
+ go s.shaperLoop()
+ go s.recvLoop()
+ go s.sendLoop()
+ if !config.KeepAliveDisabled {
+ go s.keepalive()
+ }
+ return s
+}
+
+// OpenStream is used to create a new stream
+func (s *Session) OpenStream() (*Stream, error) {
+ if s.IsClosed() {
+ return nil, io.ErrClosedPipe
+ }
+
+ // generate stream id
+ s.nextStreamIDLock.Lock()
+ if s.goAway > 0 {
+ s.nextStreamIDLock.Unlock()
+ return nil, ErrGoAway
+ }
+
+ s.nextStreamID += 2
+ sid := s.nextStreamID
+ if sid == sid%2 { // stream-id overflows
+ s.goAway = 1
+ s.nextStreamIDLock.Unlock()
+ return nil, ErrGoAway
+ }
+ s.nextStreamIDLock.Unlock()
+
+ stream := newStream(sid, s.config.MaxFrameSize, s)
+
+ if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
+ return nil, err
+ }
+
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+ select {
+ case <-s.chSocketReadError:
+ return nil, s.socketReadError.Load().(error)
+ case <-s.chSocketWriteError:
+ return nil, s.socketWriteError.Load().(error)
+ case <-s.die:
+ return nil, io.ErrClosedPipe
+ default:
+ s.streams[sid] = stream
+ return stream, nil
+ }
+}
+
+// Open returns a generic ReadWriteCloser
+func (s *Session) Open() (io.ReadWriteCloser, error) {
+ return s.OpenStream()
+}
+
+// AcceptStream is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) AcceptStream() (*Stream, error) {
+ var deadline <-chan time.Time
+ if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
+ timer := time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ select {
+ case stream := <-s.chAccepts:
+ return stream, nil
+ case <-deadline:
+ return nil, ErrTimeout
+ case <-s.chSocketReadError:
+ return nil, s.socketReadError.Load().(error)
+ case <-s.chProtoError:
+ return nil, s.protoError.Load().(error)
+ case <-s.die:
+ return nil, io.ErrClosedPipe
+ }
+}
+
+// Accept Returns a generic ReadWriteCloser instead of smux.Stream
+func (s *Session) Accept() (io.ReadWriteCloser, error) {
+ return s.AcceptStream()
+}
+
+// Close is used to close the session and all streams.
+func (s *Session) Close() error {
+ var once bool
+ s.dieOnce.Do(func() {
+ close(s.die)
+ once = true
+ })
+
+ if once {
+ s.streamLock.Lock()
+ for k := range s.streams {
+ s.streams[k].sessionClose()
+ }
+ s.streamLock.Unlock()
+ return s.conn.Close()
+ } else {
+ return io.ErrClosedPipe
+ }
+}
+
+// notifyBucket notifies recvLoop that bucket is available
+func (s *Session) notifyBucket() {
+ select {
+ case s.bucketNotify <- struct{}{}:
+ default:
+ }
+}
+
+func (s *Session) notifyReadError(err error) {
+ s.socketReadErrorOnce.Do(func() {
+ s.socketReadError.Store(err)
+ close(s.chSocketReadError)
+ })
+}
+
+func (s *Session) notifyWriteError(err error) {
+ s.socketWriteErrorOnce.Do(func() {
+ s.socketWriteError.Store(err)
+ close(s.chSocketWriteError)
+ })
+}
+
+func (s *Session) notifyProtoError(err error) {
+ s.protoErrorOnce.Do(func() {
+ s.protoError.Store(err)
+ close(s.chProtoError)
+ })
+}
+
+// IsClosed does a safe check to see if we have shutdown
+func (s *Session) IsClosed() bool {
+ select {
+ case <-s.die:
+ return true
+ default:
+ return false
+ }
+}
+
+// NumStreams returns the number of currently open streams
+func (s *Session) NumStreams() int {
+ if s.IsClosed() {
+ return 0
+ }
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+ return len(s.streams)
+}
+
+// SetDeadline sets a deadline used by Accept* calls.
+// A zero time value disables the deadline.
+func (s *Session) SetDeadline(t time.Time) error {
+ s.deadline.Store(t)
+ return nil
+}
+
+// LocalAddr satisfies net.Conn interface
+func (s *Session) LocalAddr() net.Addr {
+ if ts, ok := s.conn.(interface {
+ LocalAddr() net.Addr
+ }); ok {
+ return ts.LocalAddr()
+ }
+ return nil
+}
+
+// RemoteAddr satisfies net.Conn interface
+func (s *Session) RemoteAddr() net.Addr {
+ if ts, ok := s.conn.(interface {
+ RemoteAddr() net.Addr
+ }); ok {
+ return ts.RemoteAddr()
+ }
+ return nil
+}
+
+// notify the session that a stream has closed
+func (s *Session) streamClosed(sid uint32) {
+ s.streamLock.Lock()
+ if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
+ if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
+ s.notifyBucket()
+ }
+ }
+ delete(s.streams, sid)
+ s.streamLock.Unlock()
+}
+
+// returnTokens is called by stream to return token after read
+func (s *Session) returnTokens(n int) {
+ if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
+ s.notifyBucket()
+ }
+}
+
+// recvLoop keeps on reading from underlying connection if tokens are available
+func (s *Session) recvLoop() {
+ var hdr rawHeader
+ var updHdr updHeader
+
+ for {
+ for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
+ select {
+ case <-s.bucketNotify:
+ case <-s.die:
+ return
+ }
+ }
+
+ // read header first
+ if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
+ atomic.StoreInt32(&s.dataReady, 1)
+ if hdr.Version() != byte(s.config.Version) {
+ s.notifyProtoError(ErrInvalidProtocol)
+ return
+ }
+ sid := hdr.StreamID()
+ switch hdr.Cmd() {
+ case cmdNOP:
+ case cmdSYN:
+ s.streamLock.Lock()
+ if _, ok := s.streams[sid]; !ok {
+ stream := newStream(sid, s.config.MaxFrameSize, s)
+ s.streams[sid] = stream
+ select {
+ case s.chAccepts <- stream:
+ case <-s.die:
+ }
+ }
+ s.streamLock.Unlock()
+ case cmdFIN:
+ s.streamLock.Lock()
+ if stream, ok := s.streams[sid]; ok {
+ stream.fin()
+ stream.notifyReadEvent()
+ }
+ s.streamLock.Unlock()
+ case cmdPSH:
+ if hdr.Length() > 0 {
+ newbuf := defaultAllocator.Get(int(hdr.Length()))
+ if written, err := io.ReadFull(s.conn, newbuf); err == nil {
+ s.streamLock.Lock()
+ if stream, ok := s.streams[sid]; ok {
+ stream.pushBytes(newbuf)
+ atomic.AddInt32(&s.bucket, -int32(written))
+ stream.notifyReadEvent()
+ }
+ s.streamLock.Unlock()
+ } else {
+ s.notifyReadError(err)
+ return
+ }
+ }
+ case cmdUPD:
+ if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil {
+ s.streamLock.Lock()
+ if stream, ok := s.streams[sid]; ok {
+ stream.update(updHdr.Consumed(), updHdr.Window())
+ }
+ s.streamLock.Unlock()
+ } else {
+ s.notifyReadError(err)
+ return
+ }
+ default:
+ s.notifyProtoError(ErrInvalidProtocol)
+ return
+ }
+ } else {
+ s.notifyReadError(err)
+ return
+ }
+ }
+}
+
+func (s *Session) keepalive() {
+ tickerPing := time.NewTicker(s.config.KeepAliveInterval)
+ tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
+ defer tickerPing.Stop()
+ defer tickerTimeout.Stop()
+ for {
+ select {
+ case <-tickerPing.C:
+ s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0)
+ s.notifyBucket() // force a signal to the recvLoop
+ case <-tickerTimeout.C:
+ if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
+ // recvLoop may block while bucket is 0, in this case,
+ // session should not be closed.
+ if atomic.LoadInt32(&s.bucket) > 0 {
+ s.Close()
+ return
+ }
+ }
+ case <-s.die:
+ return
+ }
+ }
+}
+
+// shaper shapes the sending sequence among streams
+func (s *Session) shaperLoop() {
+ var reqs shaperHeap
+ var next writeRequest
+ var chWrite chan writeRequest
+
+ for {
+ if len(reqs) > 0 {
+ chWrite = s.writes
+ next = heap.Pop(&reqs).(writeRequest)
+ } else {
+ chWrite = nil
+ }
+
+ select {
+ case <-s.die:
+ return
+ case r := <-s.shaper:
+ if chWrite != nil { // next is valid, reshape
+ heap.Push(&reqs, next)
+ }
+ heap.Push(&reqs, r)
+ case chWrite <- next:
+ }
+ }
+}
+
+func (s *Session) sendLoop() {
+ var buf []byte
+ var n int
+ var err error
+ var vec [][]byte // vector for writeBuffers
+
+ bw, ok := s.conn.(buffersWriter)
+ if ok {
+ buf = make([]byte, headerSize)
+ vec = make([][]byte, 2)
+ } else {
+ buf = make([]byte, (1<<16)+headerSize)
+ }
+
+ for {
+ select {
+ case <-s.die:
+ return
+ case request := <-s.writes:
+ buf[0] = request.frame.ver
+ buf[1] = request.frame.cmd
+ binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
+ binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
+
+ if len(vec) > 0 {
+ vec[0] = buf[:headerSize]
+ vec[1] = request.frame.data
+ n, err = bw.WriteBuffers(vec)
+ } else {
+ copy(buf[headerSize:], request.frame.data)
+ n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
+ }
+
+ n -= headerSize
+ if n < 0 {
+ n = 0
+ }
+
+ result := writeResult{
+ n: n,
+ err: err,
+ }
+
+ request.result <- result
+ close(request.result)
+
+ // store conn error
+ if err != nil {
+ s.notifyWriteError(err)
+ return
+ }
+ }
+ }
+}
+
+// writeFrame writes the frame to the underlying connection
+// and returns the number of bytes written if successful
+func (s *Session) writeFrame(f Frame) (n int, err error) {
+ return s.writeFrameInternal(f, nil, 0)
+}
+
+// internal writeFrame version to support deadline used in keepalive
+func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
+ req := writeRequest{
+ prio: prio,
+ frame: f,
+ result: make(chan writeResult, 1),
+ }
+ select {
+ case s.shaper <- req:
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ case <-s.chSocketWriteError:
+ return 0, s.socketWriteError.Load().(error)
+ case <-deadline:
+ return 0, ErrTimeout
+ }
+
+ select {
+ case result := <-req.result:
+ return result.n, result.err
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ case <-s.chSocketWriteError:
+ return 0, s.socketWriteError.Load().(error)
+ case <-deadline:
+ return 0, ErrTimeout
+ }
+}