summaryrefslogtreecommitdiff
path: root/vendor/github.com/xtaci/smux/session.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/xtaci/smux/session.go')
-rw-r--r--vendor/github.com/xtaci/smux/session.go525
1 files changed, 0 insertions, 525 deletions
diff --git a/vendor/github.com/xtaci/smux/session.go b/vendor/github.com/xtaci/smux/session.go
deleted file mode 100644
index bc56066..0000000
--- a/vendor/github.com/xtaci/smux/session.go
+++ /dev/null
@@ -1,525 +0,0 @@
-package smux
-
-import (
- "container/heap"
- "encoding/binary"
- "errors"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "time"
-)
-
-const (
- defaultAcceptBacklog = 1024
-)
-
-var (
- ErrInvalidProtocol = errors.New("invalid protocol")
- ErrConsumed = errors.New("peer consumed more than sent")
- ErrGoAway = errors.New("stream id overflows, should start a new connection")
- ErrTimeout = errors.New("timeout")
- ErrWouldBlock = errors.New("operation would block on IO")
-)
-
-type writeRequest struct {
- prio uint64
- frame Frame
- result chan writeResult
-}
-
-type writeResult struct {
- n int
- err error
-}
-
-type buffersWriter interface {
- WriteBuffers(v [][]byte) (n int, err error)
-}
-
-// Session defines a multiplexed connection for streams
-type Session struct {
- conn io.ReadWriteCloser
-
- config *Config
- nextStreamID uint32 // next stream identifier
- nextStreamIDLock sync.Mutex
-
- bucket int32 // token bucket
- bucketNotify chan struct{} // used for waiting for tokens
-
- streams map[uint32]*Stream // all streams in this session
- streamLock sync.Mutex // locks streams
-
- die chan struct{} // flag session has died
- dieOnce sync.Once
-
- // socket error handling
- socketReadError atomic.Value
- socketWriteError atomic.Value
- chSocketReadError chan struct{}
- chSocketWriteError chan struct{}
- socketReadErrorOnce sync.Once
- socketWriteErrorOnce sync.Once
-
- // smux protocol errors
- protoError atomic.Value
- chProtoError chan struct{}
- protoErrorOnce sync.Once
-
- chAccepts chan *Stream
-
- dataReady int32 // flag data has arrived
-
- goAway int32 // flag id exhausted
-
- deadline atomic.Value
-
- shaper chan writeRequest // a shaper for writing
- writes chan writeRequest
-}
-
-func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
- s := new(Session)
- s.die = make(chan struct{})
- s.conn = conn
- s.config = config
- s.streams = make(map[uint32]*Stream)
- s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
- s.bucket = int32(config.MaxReceiveBuffer)
- s.bucketNotify = make(chan struct{}, 1)
- s.shaper = make(chan writeRequest)
- s.writes = make(chan writeRequest)
- s.chSocketReadError = make(chan struct{})
- s.chSocketWriteError = make(chan struct{})
- s.chProtoError = make(chan struct{})
-
- if client {
- s.nextStreamID = 1
- } else {
- s.nextStreamID = 0
- }
-
- go s.shaperLoop()
- go s.recvLoop()
- go s.sendLoop()
- if !config.KeepAliveDisabled {
- go s.keepalive()
- }
- return s
-}
-
-// OpenStream is used to create a new stream
-func (s *Session) OpenStream() (*Stream, error) {
- if s.IsClosed() {
- return nil, io.ErrClosedPipe
- }
-
- // generate stream id
- s.nextStreamIDLock.Lock()
- if s.goAway > 0 {
- s.nextStreamIDLock.Unlock()
- return nil, ErrGoAway
- }
-
- s.nextStreamID += 2
- sid := s.nextStreamID
- if sid == sid%2 { // stream-id overflows
- s.goAway = 1
- s.nextStreamIDLock.Unlock()
- return nil, ErrGoAway
- }
- s.nextStreamIDLock.Unlock()
-
- stream := newStream(sid, s.config.MaxFrameSize, s)
-
- if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
- return nil, err
- }
-
- s.streamLock.Lock()
- defer s.streamLock.Unlock()
- select {
- case <-s.chSocketReadError:
- return nil, s.socketReadError.Load().(error)
- case <-s.chSocketWriteError:
- return nil, s.socketWriteError.Load().(error)
- case <-s.die:
- return nil, io.ErrClosedPipe
- default:
- s.streams[sid] = stream
- return stream, nil
- }
-}
-
-// Open returns a generic ReadWriteCloser
-func (s *Session) Open() (io.ReadWriteCloser, error) {
- return s.OpenStream()
-}
-
-// AcceptStream is used to block until the next available stream
-// is ready to be accepted.
-func (s *Session) AcceptStream() (*Stream, error) {
- var deadline <-chan time.Time
- if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
- timer := time.NewTimer(time.Until(d))
- defer timer.Stop()
- deadline = timer.C
- }
-
- select {
- case stream := <-s.chAccepts:
- return stream, nil
- case <-deadline:
- return nil, ErrTimeout
- case <-s.chSocketReadError:
- return nil, s.socketReadError.Load().(error)
- case <-s.chProtoError:
- return nil, s.protoError.Load().(error)
- case <-s.die:
- return nil, io.ErrClosedPipe
- }
-}
-
-// Accept Returns a generic ReadWriteCloser instead of smux.Stream
-func (s *Session) Accept() (io.ReadWriteCloser, error) {
- return s.AcceptStream()
-}
-
-// Close is used to close the session and all streams.
-func (s *Session) Close() error {
- var once bool
- s.dieOnce.Do(func() {
- close(s.die)
- once = true
- })
-
- if once {
- s.streamLock.Lock()
- for k := range s.streams {
- s.streams[k].sessionClose()
- }
- s.streamLock.Unlock()
- return s.conn.Close()
- } else {
- return io.ErrClosedPipe
- }
-}
-
-// notifyBucket notifies recvLoop that bucket is available
-func (s *Session) notifyBucket() {
- select {
- case s.bucketNotify <- struct{}{}:
- default:
- }
-}
-
-func (s *Session) notifyReadError(err error) {
- s.socketReadErrorOnce.Do(func() {
- s.socketReadError.Store(err)
- close(s.chSocketReadError)
- })
-}
-
-func (s *Session) notifyWriteError(err error) {
- s.socketWriteErrorOnce.Do(func() {
- s.socketWriteError.Store(err)
- close(s.chSocketWriteError)
- })
-}
-
-func (s *Session) notifyProtoError(err error) {
- s.protoErrorOnce.Do(func() {
- s.protoError.Store(err)
- close(s.chProtoError)
- })
-}
-
-// IsClosed does a safe check to see if we have shutdown
-func (s *Session) IsClosed() bool {
- select {
- case <-s.die:
- return true
- default:
- return false
- }
-}
-
-// NumStreams returns the number of currently open streams
-func (s *Session) NumStreams() int {
- if s.IsClosed() {
- return 0
- }
- s.streamLock.Lock()
- defer s.streamLock.Unlock()
- return len(s.streams)
-}
-
-// SetDeadline sets a deadline used by Accept* calls.
-// A zero time value disables the deadline.
-func (s *Session) SetDeadline(t time.Time) error {
- s.deadline.Store(t)
- return nil
-}
-
-// LocalAddr satisfies net.Conn interface
-func (s *Session) LocalAddr() net.Addr {
- if ts, ok := s.conn.(interface {
- LocalAddr() net.Addr
- }); ok {
- return ts.LocalAddr()
- }
- return nil
-}
-
-// RemoteAddr satisfies net.Conn interface
-func (s *Session) RemoteAddr() net.Addr {
- if ts, ok := s.conn.(interface {
- RemoteAddr() net.Addr
- }); ok {
- return ts.RemoteAddr()
- }
- return nil
-}
-
-// notify the session that a stream has closed
-func (s *Session) streamClosed(sid uint32) {
- s.streamLock.Lock()
- if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
- if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
- s.notifyBucket()
- }
- }
- delete(s.streams, sid)
- s.streamLock.Unlock()
-}
-
-// returnTokens is called by stream to return token after read
-func (s *Session) returnTokens(n int) {
- if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
- s.notifyBucket()
- }
-}
-
-// recvLoop keeps on reading from underlying connection if tokens are available
-func (s *Session) recvLoop() {
- var hdr rawHeader
- var updHdr updHeader
-
- for {
- for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
- select {
- case <-s.bucketNotify:
- case <-s.die:
- return
- }
- }
-
- // read header first
- if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
- atomic.StoreInt32(&s.dataReady, 1)
- if hdr.Version() != byte(s.config.Version) {
- s.notifyProtoError(ErrInvalidProtocol)
- return
- }
- sid := hdr.StreamID()
- switch hdr.Cmd() {
- case cmdNOP:
- case cmdSYN:
- s.streamLock.Lock()
- if _, ok := s.streams[sid]; !ok {
- stream := newStream(sid, s.config.MaxFrameSize, s)
- s.streams[sid] = stream
- select {
- case s.chAccepts <- stream:
- case <-s.die:
- }
- }
- s.streamLock.Unlock()
- case cmdFIN:
- s.streamLock.Lock()
- if stream, ok := s.streams[sid]; ok {
- stream.fin()
- stream.notifyReadEvent()
- }
- s.streamLock.Unlock()
- case cmdPSH:
- if hdr.Length() > 0 {
- newbuf := defaultAllocator.Get(int(hdr.Length()))
- if written, err := io.ReadFull(s.conn, newbuf); err == nil {
- s.streamLock.Lock()
- if stream, ok := s.streams[sid]; ok {
- stream.pushBytes(newbuf)
- atomic.AddInt32(&s.bucket, -int32(written))
- stream.notifyReadEvent()
- }
- s.streamLock.Unlock()
- } else {
- s.notifyReadError(err)
- return
- }
- }
- case cmdUPD:
- if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil {
- s.streamLock.Lock()
- if stream, ok := s.streams[sid]; ok {
- stream.update(updHdr.Consumed(), updHdr.Window())
- }
- s.streamLock.Unlock()
- } else {
- s.notifyReadError(err)
- return
- }
- default:
- s.notifyProtoError(ErrInvalidProtocol)
- return
- }
- } else {
- s.notifyReadError(err)
- return
- }
- }
-}
-
-func (s *Session) keepalive() {
- tickerPing := time.NewTicker(s.config.KeepAliveInterval)
- tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
- defer tickerPing.Stop()
- defer tickerTimeout.Stop()
- for {
- select {
- case <-tickerPing.C:
- s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0)
- s.notifyBucket() // force a signal to the recvLoop
- case <-tickerTimeout.C:
- if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
- // recvLoop may block while bucket is 0, in this case,
- // session should not be closed.
- if atomic.LoadInt32(&s.bucket) > 0 {
- s.Close()
- return
- }
- }
- case <-s.die:
- return
- }
- }
-}
-
-// shaper shapes the sending sequence among streams
-func (s *Session) shaperLoop() {
- var reqs shaperHeap
- var next writeRequest
- var chWrite chan writeRequest
-
- for {
- if len(reqs) > 0 {
- chWrite = s.writes
- next = heap.Pop(&reqs).(writeRequest)
- } else {
- chWrite = nil
- }
-
- select {
- case <-s.die:
- return
- case r := <-s.shaper:
- if chWrite != nil { // next is valid, reshape
- heap.Push(&reqs, next)
- }
- heap.Push(&reqs, r)
- case chWrite <- next:
- }
- }
-}
-
-func (s *Session) sendLoop() {
- var buf []byte
- var n int
- var err error
- var vec [][]byte // vector for writeBuffers
-
- bw, ok := s.conn.(buffersWriter)
- if ok {
- buf = make([]byte, headerSize)
- vec = make([][]byte, 2)
- } else {
- buf = make([]byte, (1<<16)+headerSize)
- }
-
- for {
- select {
- case <-s.die:
- return
- case request := <-s.writes:
- buf[0] = request.frame.ver
- buf[1] = request.frame.cmd
- binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
- binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
-
- if len(vec) > 0 {
- vec[0] = buf[:headerSize]
- vec[1] = request.frame.data
- n, err = bw.WriteBuffers(vec)
- } else {
- copy(buf[headerSize:], request.frame.data)
- n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
- }
-
- n -= headerSize
- if n < 0 {
- n = 0
- }
-
- result := writeResult{
- n: n,
- err: err,
- }
-
- request.result <- result
- close(request.result)
-
- // store conn error
- if err != nil {
- s.notifyWriteError(err)
- return
- }
- }
- }
-}
-
-// writeFrame writes the frame to the underlying connection
-// and returns the number of bytes written if successful
-func (s *Session) writeFrame(f Frame) (n int, err error) {
- return s.writeFrameInternal(f, nil, 0)
-}
-
-// internal writeFrame version to support deadline used in keepalive
-func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
- req := writeRequest{
- prio: prio,
- frame: f,
- result: make(chan writeResult, 1),
- }
- select {
- case s.shaper <- req:
- case <-s.die:
- return 0, io.ErrClosedPipe
- case <-s.chSocketWriteError:
- return 0, s.socketWriteError.Load().(error)
- case <-deadline:
- return 0, ErrTimeout
- }
-
- select {
- case result := <-req.result:
- return result.n, result.err
- case <-s.die:
- return 0, io.ErrClosedPipe
- case <-s.chSocketWriteError:
- return 0, s.socketWriteError.Load().(error)
- case <-deadline:
- return 0, ErrTimeout
- }
-}