diff options
Diffstat (limited to 'vendor/github.com/xtaci/smux/session.go')
-rw-r--r-- | vendor/github.com/xtaci/smux/session.go | 525 |
1 files changed, 525 insertions, 0 deletions
diff --git a/vendor/github.com/xtaci/smux/session.go b/vendor/github.com/xtaci/smux/session.go new file mode 100644 index 0000000..bc56066 --- /dev/null +++ b/vendor/github.com/xtaci/smux/session.go @@ -0,0 +1,525 @@ +package smux + +import ( + "container/heap" + "encoding/binary" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + defaultAcceptBacklog = 1024 +) + +var ( + ErrInvalidProtocol = errors.New("invalid protocol") + ErrConsumed = errors.New("peer consumed more than sent") + ErrGoAway = errors.New("stream id overflows, should start a new connection") + ErrTimeout = errors.New("timeout") + ErrWouldBlock = errors.New("operation would block on IO") +) + +type writeRequest struct { + prio uint64 + frame Frame + result chan writeResult +} + +type writeResult struct { + n int + err error +} + +type buffersWriter interface { + WriteBuffers(v [][]byte) (n int, err error) +} + +// Session defines a multiplexed connection for streams +type Session struct { + conn io.ReadWriteCloser + + config *Config + nextStreamID uint32 // next stream identifier + nextStreamIDLock sync.Mutex + + bucket int32 // token bucket + bucketNotify chan struct{} // used for waiting for tokens + + streams map[uint32]*Stream // all streams in this session + streamLock sync.Mutex // locks streams + + die chan struct{} // flag session has died + dieOnce sync.Once + + // socket error handling + socketReadError atomic.Value + socketWriteError atomic.Value + chSocketReadError chan struct{} + chSocketWriteError chan struct{} + socketReadErrorOnce sync.Once + socketWriteErrorOnce sync.Once + + // smux protocol errors + protoError atomic.Value + chProtoError chan struct{} + protoErrorOnce sync.Once + + chAccepts chan *Stream + + dataReady int32 // flag data has arrived + + goAway int32 // flag id exhausted + + deadline atomic.Value + + shaper chan writeRequest // a shaper for writing + writes chan writeRequest +} + +func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { + s := new(Session) + s.die = make(chan struct{}) + s.conn = conn + s.config = config + s.streams = make(map[uint32]*Stream) + s.chAccepts = make(chan *Stream, defaultAcceptBacklog) + s.bucket = int32(config.MaxReceiveBuffer) + s.bucketNotify = make(chan struct{}, 1) + s.shaper = make(chan writeRequest) + s.writes = make(chan writeRequest) + s.chSocketReadError = make(chan struct{}) + s.chSocketWriteError = make(chan struct{}) + s.chProtoError = make(chan struct{}) + + if client { + s.nextStreamID = 1 + } else { + s.nextStreamID = 0 + } + + go s.shaperLoop() + go s.recvLoop() + go s.sendLoop() + if !config.KeepAliveDisabled { + go s.keepalive() + } + return s +} + +// OpenStream is used to create a new stream +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, io.ErrClosedPipe + } + + // generate stream id + s.nextStreamIDLock.Lock() + if s.goAway > 0 { + s.nextStreamIDLock.Unlock() + return nil, ErrGoAway + } + + s.nextStreamID += 2 + sid := s.nextStreamID + if sid == sid%2 { // stream-id overflows + s.goAway = 1 + s.nextStreamIDLock.Unlock() + return nil, ErrGoAway + } + s.nextStreamIDLock.Unlock() + + stream := newStream(sid, s.config.MaxFrameSize, s) + + if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { + return nil, err + } + + s.streamLock.Lock() + defer s.streamLock.Unlock() + select { + case <-s.chSocketReadError: + return nil, s.socketReadError.Load().(error) + case <-s.chSocketWriteError: + return nil, s.socketWriteError.Load().(error) + case <-s.die: + return nil, io.ErrClosedPipe + default: + s.streams[sid] = stream + return stream, nil + } +} + +// Open returns a generic ReadWriteCloser +func (s *Session) Open() (io.ReadWriteCloser, error) { + return s.OpenStream() +} + +// AcceptStream is used to block until the next available stream +// is ready to be accepted. +func (s *Session) AcceptStream() (*Stream, error) { + var deadline <-chan time.Time + if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case stream := <-s.chAccepts: + return stream, nil + case <-deadline: + return nil, ErrTimeout + case <-s.chSocketReadError: + return nil, s.socketReadError.Load().(error) + case <-s.chProtoError: + return nil, s.protoError.Load().(error) + case <-s.die: + return nil, io.ErrClosedPipe + } +} + +// Accept Returns a generic ReadWriteCloser instead of smux.Stream +func (s *Session) Accept() (io.ReadWriteCloser, error) { + return s.AcceptStream() +} + +// Close is used to close the session and all streams. +func (s *Session) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + s.streamLock.Lock() + for k := range s.streams { + s.streams[k].sessionClose() + } + s.streamLock.Unlock() + return s.conn.Close() + } else { + return io.ErrClosedPipe + } +} + +// notifyBucket notifies recvLoop that bucket is available +func (s *Session) notifyBucket() { + select { + case s.bucketNotify <- struct{}{}: + default: + } +} + +func (s *Session) notifyReadError(err error) { + s.socketReadErrorOnce.Do(func() { + s.socketReadError.Store(err) + close(s.chSocketReadError) + }) +} + +func (s *Session) notifyWriteError(err error) { + s.socketWriteErrorOnce.Do(func() { + s.socketWriteError.Store(err) + close(s.chSocketWriteError) + }) +} + +func (s *Session) notifyProtoError(err error) { + s.protoErrorOnce.Do(func() { + s.protoError.Store(err) + close(s.chProtoError) + }) +} + +// IsClosed does a safe check to see if we have shutdown +func (s *Session) IsClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + +// NumStreams returns the number of currently open streams +func (s *Session) NumStreams() int { + if s.IsClosed() { + return 0 + } + s.streamLock.Lock() + defer s.streamLock.Unlock() + return len(s.streams) +} + +// SetDeadline sets a deadline used by Accept* calls. +// A zero time value disables the deadline. +func (s *Session) SetDeadline(t time.Time) error { + s.deadline.Store(t) + return nil +} + +// LocalAddr satisfies net.Conn interface +func (s *Session) LocalAddr() net.Addr { + if ts, ok := s.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *Session) RemoteAddr() net.Addr { + if ts, ok := s.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} + +// notify the session that a stream has closed +func (s *Session) streamClosed(sid uint32) { + s.streamLock.Lock() + if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket + if atomic.AddInt32(&s.bucket, int32(n)) > 0 { + s.notifyBucket() + } + } + delete(s.streams, sid) + s.streamLock.Unlock() +} + +// returnTokens is called by stream to return token after read +func (s *Session) returnTokens(n int) { + if atomic.AddInt32(&s.bucket, int32(n)) > 0 { + s.notifyBucket() + } +} + +// recvLoop keeps on reading from underlying connection if tokens are available +func (s *Session) recvLoop() { + var hdr rawHeader + var updHdr updHeader + + for { + for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { + select { + case <-s.bucketNotify: + case <-s.die: + return + } + } + + // read header first + if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { + atomic.StoreInt32(&s.dataReady, 1) + if hdr.Version() != byte(s.config.Version) { + s.notifyProtoError(ErrInvalidProtocol) + return + } + sid := hdr.StreamID() + switch hdr.Cmd() { + case cmdNOP: + case cmdSYN: + s.streamLock.Lock() + if _, ok := s.streams[sid]; !ok { + stream := newStream(sid, s.config.MaxFrameSize, s) + s.streams[sid] = stream + select { + case s.chAccepts <- stream: + case <-s.die: + } + } + s.streamLock.Unlock() + case cmdFIN: + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.fin() + stream.notifyReadEvent() + } + s.streamLock.Unlock() + case cmdPSH: + if hdr.Length() > 0 { + newbuf := defaultAllocator.Get(int(hdr.Length())) + if written, err := io.ReadFull(s.conn, newbuf); err == nil { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.pushBytes(newbuf) + atomic.AddInt32(&s.bucket, -int32(written)) + stream.notifyReadEvent() + } + s.streamLock.Unlock() + } else { + s.notifyReadError(err) + return + } + } + case cmdUPD: + if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.update(updHdr.Consumed(), updHdr.Window()) + } + s.streamLock.Unlock() + } else { + s.notifyReadError(err) + return + } + default: + s.notifyProtoError(ErrInvalidProtocol) + return + } + } else { + s.notifyReadError(err) + return + } + } +} + +func (s *Session) keepalive() { + tickerPing := time.NewTicker(s.config.KeepAliveInterval) + tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) + defer tickerPing.Stop() + defer tickerTimeout.Stop() + for { + select { + case <-tickerPing.C: + s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0) + s.notifyBucket() // force a signal to the recvLoop + case <-tickerTimeout.C: + if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { + // recvLoop may block while bucket is 0, in this case, + // session should not be closed. + if atomic.LoadInt32(&s.bucket) > 0 { + s.Close() + return + } + } + case <-s.die: + return + } + } +} + +// shaper shapes the sending sequence among streams +func (s *Session) shaperLoop() { + var reqs shaperHeap + var next writeRequest + var chWrite chan writeRequest + + for { + if len(reqs) > 0 { + chWrite = s.writes + next = heap.Pop(&reqs).(writeRequest) + } else { + chWrite = nil + } + + select { + case <-s.die: + return + case r := <-s.shaper: + if chWrite != nil { // next is valid, reshape + heap.Push(&reqs, next) + } + heap.Push(&reqs, r) + case chWrite <- next: + } + } +} + +func (s *Session) sendLoop() { + var buf []byte + var n int + var err error + var vec [][]byte // vector for writeBuffers + + bw, ok := s.conn.(buffersWriter) + if ok { + buf = make([]byte, headerSize) + vec = make([][]byte, 2) + } else { + buf = make([]byte, (1<<16)+headerSize) + } + + for { + select { + case <-s.die: + return + case request := <-s.writes: + buf[0] = request.frame.ver + buf[1] = request.frame.cmd + binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) + binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) + + if len(vec) > 0 { + vec[0] = buf[:headerSize] + vec[1] = request.frame.data + n, err = bw.WriteBuffers(vec) + } else { + copy(buf[headerSize:], request.frame.data) + n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) + } + + n -= headerSize + if n < 0 { + n = 0 + } + + result := writeResult{ + n: n, + err: err, + } + + request.result <- result + close(request.result) + + // store conn error + if err != nil { + s.notifyWriteError(err) + return + } + } + } +} + +// writeFrame writes the frame to the underlying connection +// and returns the number of bytes written if successful +func (s *Session) writeFrame(f Frame) (n int, err error) { + return s.writeFrameInternal(f, nil, 0) +} + +// internal writeFrame version to support deadline used in keepalive +func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) { + req := writeRequest{ + prio: prio, + frame: f, + result: make(chan writeResult, 1), + } + select { + case s.shaper <- req: + case <-s.die: + return 0, io.ErrClosedPipe + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-deadline: + return 0, ErrTimeout + } + + select { + case result := <-req.result: + return result.n, result.err + case <-s.die: + return 0, io.ErrClosedPipe + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-deadline: + return 0, ErrTimeout + } +} |