diff options
Diffstat (limited to 'vendor/github.com/xtaci/smux/session.go')
-rw-r--r-- | vendor/github.com/xtaci/smux/session.go | 525 |
1 files changed, 0 insertions, 525 deletions
diff --git a/vendor/github.com/xtaci/smux/session.go b/vendor/github.com/xtaci/smux/session.go deleted file mode 100644 index bc56066..0000000 --- a/vendor/github.com/xtaci/smux/session.go +++ /dev/null @@ -1,525 +0,0 @@ -package smux - -import ( - "container/heap" - "encoding/binary" - "errors" - "io" - "net" - "sync" - "sync/atomic" - "time" -) - -const ( - defaultAcceptBacklog = 1024 -) - -var ( - ErrInvalidProtocol = errors.New("invalid protocol") - ErrConsumed = errors.New("peer consumed more than sent") - ErrGoAway = errors.New("stream id overflows, should start a new connection") - ErrTimeout = errors.New("timeout") - ErrWouldBlock = errors.New("operation would block on IO") -) - -type writeRequest struct { - prio uint64 - frame Frame - result chan writeResult -} - -type writeResult struct { - n int - err error -} - -type buffersWriter interface { - WriteBuffers(v [][]byte) (n int, err error) -} - -// Session defines a multiplexed connection for streams -type Session struct { - conn io.ReadWriteCloser - - config *Config - nextStreamID uint32 // next stream identifier - nextStreamIDLock sync.Mutex - - bucket int32 // token bucket - bucketNotify chan struct{} // used for waiting for tokens - - streams map[uint32]*Stream // all streams in this session - streamLock sync.Mutex // locks streams - - die chan struct{} // flag session has died - dieOnce sync.Once - - // socket error handling - socketReadError atomic.Value - socketWriteError atomic.Value - chSocketReadError chan struct{} - chSocketWriteError chan struct{} - socketReadErrorOnce sync.Once - socketWriteErrorOnce sync.Once - - // smux protocol errors - protoError atomic.Value - chProtoError chan struct{} - protoErrorOnce sync.Once - - chAccepts chan *Stream - - dataReady int32 // flag data has arrived - - goAway int32 // flag id exhausted - - deadline atomic.Value - - shaper chan writeRequest // a shaper for writing - writes chan writeRequest -} - -func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { - s := new(Session) - s.die = make(chan struct{}) - s.conn = conn - s.config = config - s.streams = make(map[uint32]*Stream) - s.chAccepts = make(chan *Stream, defaultAcceptBacklog) - s.bucket = int32(config.MaxReceiveBuffer) - s.bucketNotify = make(chan struct{}, 1) - s.shaper = make(chan writeRequest) - s.writes = make(chan writeRequest) - s.chSocketReadError = make(chan struct{}) - s.chSocketWriteError = make(chan struct{}) - s.chProtoError = make(chan struct{}) - - if client { - s.nextStreamID = 1 - } else { - s.nextStreamID = 0 - } - - go s.shaperLoop() - go s.recvLoop() - go s.sendLoop() - if !config.KeepAliveDisabled { - go s.keepalive() - } - return s -} - -// OpenStream is used to create a new stream -func (s *Session) OpenStream() (*Stream, error) { - if s.IsClosed() { - return nil, io.ErrClosedPipe - } - - // generate stream id - s.nextStreamIDLock.Lock() - if s.goAway > 0 { - s.nextStreamIDLock.Unlock() - return nil, ErrGoAway - } - - s.nextStreamID += 2 - sid := s.nextStreamID - if sid == sid%2 { // stream-id overflows - s.goAway = 1 - s.nextStreamIDLock.Unlock() - return nil, ErrGoAway - } - s.nextStreamIDLock.Unlock() - - stream := newStream(sid, s.config.MaxFrameSize, s) - - if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { - return nil, err - } - - s.streamLock.Lock() - defer s.streamLock.Unlock() - select { - case <-s.chSocketReadError: - return nil, s.socketReadError.Load().(error) - case <-s.chSocketWriteError: - return nil, s.socketWriteError.Load().(error) - case <-s.die: - return nil, io.ErrClosedPipe - default: - s.streams[sid] = stream - return stream, nil - } -} - -// Open returns a generic ReadWriteCloser -func (s *Session) Open() (io.ReadWriteCloser, error) { - return s.OpenStream() -} - -// AcceptStream is used to block until the next available stream -// is ready to be accepted. -func (s *Session) AcceptStream() (*Stream, error) { - var deadline <-chan time.Time - if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { - timer := time.NewTimer(time.Until(d)) - defer timer.Stop() - deadline = timer.C - } - - select { - case stream := <-s.chAccepts: - return stream, nil - case <-deadline: - return nil, ErrTimeout - case <-s.chSocketReadError: - return nil, s.socketReadError.Load().(error) - case <-s.chProtoError: - return nil, s.protoError.Load().(error) - case <-s.die: - return nil, io.ErrClosedPipe - } -} - -// Accept Returns a generic ReadWriteCloser instead of smux.Stream -func (s *Session) Accept() (io.ReadWriteCloser, error) { - return s.AcceptStream() -} - -// Close is used to close the session and all streams. -func (s *Session) Close() error { - var once bool - s.dieOnce.Do(func() { - close(s.die) - once = true - }) - - if once { - s.streamLock.Lock() - for k := range s.streams { - s.streams[k].sessionClose() - } - s.streamLock.Unlock() - return s.conn.Close() - } else { - return io.ErrClosedPipe - } -} - -// notifyBucket notifies recvLoop that bucket is available -func (s *Session) notifyBucket() { - select { - case s.bucketNotify <- struct{}{}: - default: - } -} - -func (s *Session) notifyReadError(err error) { - s.socketReadErrorOnce.Do(func() { - s.socketReadError.Store(err) - close(s.chSocketReadError) - }) -} - -func (s *Session) notifyWriteError(err error) { - s.socketWriteErrorOnce.Do(func() { - s.socketWriteError.Store(err) - close(s.chSocketWriteError) - }) -} - -func (s *Session) notifyProtoError(err error) { - s.protoErrorOnce.Do(func() { - s.protoError.Store(err) - close(s.chProtoError) - }) -} - -// IsClosed does a safe check to see if we have shutdown -func (s *Session) IsClosed() bool { - select { - case <-s.die: - return true - default: - return false - } -} - -// NumStreams returns the number of currently open streams -func (s *Session) NumStreams() int { - if s.IsClosed() { - return 0 - } - s.streamLock.Lock() - defer s.streamLock.Unlock() - return len(s.streams) -} - -// SetDeadline sets a deadline used by Accept* calls. -// A zero time value disables the deadline. -func (s *Session) SetDeadline(t time.Time) error { - s.deadline.Store(t) - return nil -} - -// LocalAddr satisfies net.Conn interface -func (s *Session) LocalAddr() net.Addr { - if ts, ok := s.conn.(interface { - LocalAddr() net.Addr - }); ok { - return ts.LocalAddr() - } - return nil -} - -// RemoteAddr satisfies net.Conn interface -func (s *Session) RemoteAddr() net.Addr { - if ts, ok := s.conn.(interface { - RemoteAddr() net.Addr - }); ok { - return ts.RemoteAddr() - } - return nil -} - -// notify the session that a stream has closed -func (s *Session) streamClosed(sid uint32) { - s.streamLock.Lock() - if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } - } - delete(s.streams, sid) - s.streamLock.Unlock() -} - -// returnTokens is called by stream to return token after read -func (s *Session) returnTokens(n int) { - if atomic.AddInt32(&s.bucket, int32(n)) > 0 { - s.notifyBucket() - } -} - -// recvLoop keeps on reading from underlying connection if tokens are available -func (s *Session) recvLoop() { - var hdr rawHeader - var updHdr updHeader - - for { - for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { - select { - case <-s.bucketNotify: - case <-s.die: - return - } - } - - // read header first - if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { - atomic.StoreInt32(&s.dataReady, 1) - if hdr.Version() != byte(s.config.Version) { - s.notifyProtoError(ErrInvalidProtocol) - return - } - sid := hdr.StreamID() - switch hdr.Cmd() { - case cmdNOP: - case cmdSYN: - s.streamLock.Lock() - if _, ok := s.streams[sid]; !ok { - stream := newStream(sid, s.config.MaxFrameSize, s) - s.streams[sid] = stream - select { - case s.chAccepts <- stream: - case <-s.die: - } - } - s.streamLock.Unlock() - case cmdFIN: - s.streamLock.Lock() - if stream, ok := s.streams[sid]; ok { - stream.fin() - stream.notifyReadEvent() - } - s.streamLock.Unlock() - case cmdPSH: - if hdr.Length() > 0 { - newbuf := defaultAllocator.Get(int(hdr.Length())) - if written, err := io.ReadFull(s.conn, newbuf); err == nil { - s.streamLock.Lock() - if stream, ok := s.streams[sid]; ok { - stream.pushBytes(newbuf) - atomic.AddInt32(&s.bucket, -int32(written)) - stream.notifyReadEvent() - } - s.streamLock.Unlock() - } else { - s.notifyReadError(err) - return - } - } - case cmdUPD: - if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil { - s.streamLock.Lock() - if stream, ok := s.streams[sid]; ok { - stream.update(updHdr.Consumed(), updHdr.Window()) - } - s.streamLock.Unlock() - } else { - s.notifyReadError(err) - return - } - default: - s.notifyProtoError(ErrInvalidProtocol) - return - } - } else { - s.notifyReadError(err) - return - } - } -} - -func (s *Session) keepalive() { - tickerPing := time.NewTicker(s.config.KeepAliveInterval) - tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) - defer tickerPing.Stop() - defer tickerTimeout.Stop() - for { - select { - case <-tickerPing.C: - s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0) - s.notifyBucket() // force a signal to the recvLoop - case <-tickerTimeout.C: - if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { - // recvLoop may block while bucket is 0, in this case, - // session should not be closed. - if atomic.LoadInt32(&s.bucket) > 0 { - s.Close() - return - } - } - case <-s.die: - return - } - } -} - -// shaper shapes the sending sequence among streams -func (s *Session) shaperLoop() { - var reqs shaperHeap - var next writeRequest - var chWrite chan writeRequest - - for { - if len(reqs) > 0 { - chWrite = s.writes - next = heap.Pop(&reqs).(writeRequest) - } else { - chWrite = nil - } - - select { - case <-s.die: - return - case r := <-s.shaper: - if chWrite != nil { // next is valid, reshape - heap.Push(&reqs, next) - } - heap.Push(&reqs, r) - case chWrite <- next: - } - } -} - -func (s *Session) sendLoop() { - var buf []byte - var n int - var err error - var vec [][]byte // vector for writeBuffers - - bw, ok := s.conn.(buffersWriter) - if ok { - buf = make([]byte, headerSize) - vec = make([][]byte, 2) - } else { - buf = make([]byte, (1<<16)+headerSize) - } - - for { - select { - case <-s.die: - return - case request := <-s.writes: - buf[0] = request.frame.ver - buf[1] = request.frame.cmd - binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) - binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) - - if len(vec) > 0 { - vec[0] = buf[:headerSize] - vec[1] = request.frame.data - n, err = bw.WriteBuffers(vec) - } else { - copy(buf[headerSize:], request.frame.data) - n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) - } - - n -= headerSize - if n < 0 { - n = 0 - } - - result := writeResult{ - n: n, - err: err, - } - - request.result <- result - close(request.result) - - // store conn error - if err != nil { - s.notifyWriteError(err) - return - } - } - } -} - -// writeFrame writes the frame to the underlying connection -// and returns the number of bytes written if successful -func (s *Session) writeFrame(f Frame) (n int, err error) { - return s.writeFrameInternal(f, nil, 0) -} - -// internal writeFrame version to support deadline used in keepalive -func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) { - req := writeRequest{ - prio: prio, - frame: f, - result: make(chan writeResult, 1), - } - select { - case s.shaper <- req: - case <-s.die: - return 0, io.ErrClosedPipe - case <-s.chSocketWriteError: - return 0, s.socketWriteError.Load().(error) - case <-deadline: - return 0, ErrTimeout - } - - select { - case result := <-req.result: - return result.n, result.err - case <-s.die: - return 0, io.ErrClosedPipe - case <-s.chSocketWriteError: - return 0, s.socketWriteError.Load().(error) - case <-deadline: - return 0, ErrTimeout - } -} |