summaryrefslogtreecommitdiff
path: root/vendor/github.com/xtaci/smux/stream.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/xtaci/smux/stream.go')
-rw-r--r--vendor/github.com/xtaci/smux/stream.go549
1 files changed, 549 insertions, 0 deletions
diff --git a/vendor/github.com/xtaci/smux/stream.go b/vendor/github.com/xtaci/smux/stream.go
new file mode 100644
index 0000000..6c9499c
--- /dev/null
+++ b/vendor/github.com/xtaci/smux/stream.go
@@ -0,0 +1,549 @@
+package smux
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Stream implements net.Conn
+type Stream struct {
+ id uint32
+ sess *Session
+
+ buffers [][]byte
+ heads [][]byte // slice heads kept for recycle
+
+ bufferLock sync.Mutex
+ frameSize int
+
+ // notify a read event
+ chReadEvent chan struct{}
+
+ // flag the stream has closed
+ die chan struct{}
+ dieOnce sync.Once
+
+ // FIN command
+ chFinEvent chan struct{}
+ finEventOnce sync.Once
+
+ // deadlines
+ readDeadline atomic.Value
+ writeDeadline atomic.Value
+
+ // per stream sliding window control
+ numRead uint32 // number of consumed bytes
+ numWritten uint32 // count num of bytes written
+ incr uint32 // counting for sending
+
+ // UPD command
+ peerConsumed uint32 // num of bytes the peer has consumed
+ peerWindow uint32 // peer window, initialized to 256KB, updated by peer
+ chUpdate chan struct{} // notify of remote data consuming and window update
+}
+
+// newStream initiates a Stream struct
+func newStream(id uint32, frameSize int, sess *Session) *Stream {
+ s := new(Stream)
+ s.id = id
+ s.chReadEvent = make(chan struct{}, 1)
+ s.chUpdate = make(chan struct{}, 1)
+ s.frameSize = frameSize
+ s.sess = sess
+ s.die = make(chan struct{})
+ s.chFinEvent = make(chan struct{})
+ s.peerWindow = initialPeerWindow // set to initial window size
+ return s
+}
+
+// ID returns the unique stream ID.
+func (s *Stream) ID() uint32 {
+ return s.id
+}
+
+// Read implements net.Conn
+func (s *Stream) Read(b []byte) (n int, err error) {
+ for {
+ n, err = s.tryRead(b)
+ if err == ErrWouldBlock {
+ if ew := s.waitRead(); ew != nil {
+ return 0, ew
+ }
+ } else {
+ return n, err
+ }
+ }
+}
+
+// tryRead is the nonblocking version of Read
+func (s *Stream) tryRead(b []byte) (n int, err error) {
+ if s.sess.config.Version == 2 {
+ return s.tryReadv2(b)
+ }
+
+ if len(b) == 0 {
+ return 0, nil
+ }
+
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ n = copy(b, s.buffers[0])
+ s.buffers[0] = s.buffers[0][n:]
+ if len(s.buffers[0]) == 0 {
+ s.buffers[0] = nil
+ s.buffers = s.buffers[1:]
+ // full recycle
+ defaultAllocator.Put(s.heads[0])
+ s.heads = s.heads[1:]
+ }
+ }
+ s.bufferLock.Unlock()
+
+ if n > 0 {
+ s.sess.returnTokens(n)
+ return n, nil
+ }
+
+ select {
+ case <-s.die:
+ return 0, io.EOF
+ default:
+ return 0, ErrWouldBlock
+ }
+}
+
+func (s *Stream) tryReadv2(b []byte) (n int, err error) {
+ if len(b) == 0 {
+ return 0, nil
+ }
+
+ var notifyConsumed uint32
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ n = copy(b, s.buffers[0])
+ s.buffers[0] = s.buffers[0][n:]
+ if len(s.buffers[0]) == 0 {
+ s.buffers[0] = nil
+ s.buffers = s.buffers[1:]
+ // full recycle
+ defaultAllocator.Put(s.heads[0])
+ s.heads = s.heads[1:]
+ }
+ }
+
+ // in an ideal environment:
+ // if more than half of buffer has consumed, send read ack to peer
+ // based on round-trip time of ACK, continous flowing data
+ // won't slow down because of waiting for ACK, as long as the
+ // consumer keeps on reading data
+ // s.numRead == n also notify window at the first read
+ s.numRead += uint32(n)
+ s.incr += uint32(n)
+ if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
+ notifyConsumed = s.numRead
+ s.incr = 0
+ }
+ s.bufferLock.Unlock()
+
+ if n > 0 {
+ s.sess.returnTokens(n)
+ if notifyConsumed > 0 {
+ err := s.sendWindowUpdate(notifyConsumed)
+ return n, err
+ } else {
+ return n, nil
+ }
+ }
+
+ select {
+ case <-s.die:
+ return 0, io.EOF
+ default:
+ return 0, ErrWouldBlock
+ }
+}
+
+// WriteTo implements io.WriteTo
+func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
+ if s.sess.config.Version == 2 {
+ return s.writeTov2(w)
+ }
+
+ for {
+ var buf []byte
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ buf = s.buffers[0]
+ s.buffers = s.buffers[1:]
+ s.heads = s.heads[1:]
+ }
+ s.bufferLock.Unlock()
+
+ if buf != nil {
+ nw, ew := w.Write(buf)
+ s.sess.returnTokens(len(buf))
+ defaultAllocator.Put(buf)
+ if nw > 0 {
+ n += int64(nw)
+ }
+
+ if ew != nil {
+ return n, ew
+ }
+ } else if ew := s.waitRead(); ew != nil {
+ return n, ew
+ }
+ }
+}
+
+func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
+ for {
+ var notifyConsumed uint32
+ var buf []byte
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ buf = s.buffers[0]
+ s.buffers = s.buffers[1:]
+ s.heads = s.heads[1:]
+ }
+ s.numRead += uint32(len(buf))
+ s.incr += uint32(len(buf))
+ if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
+ notifyConsumed = s.numRead
+ s.incr = 0
+ }
+ s.bufferLock.Unlock()
+
+ if buf != nil {
+ nw, ew := w.Write(buf)
+ s.sess.returnTokens(len(buf))
+ defaultAllocator.Put(buf)
+ if nw > 0 {
+ n += int64(nw)
+ }
+
+ if ew != nil {
+ return n, ew
+ }
+
+ if notifyConsumed > 0 {
+ if err := s.sendWindowUpdate(notifyConsumed); err != nil {
+ return n, err
+ }
+ }
+ } else if ew := s.waitRead(); ew != nil {
+ return n, ew
+ }
+ }
+}
+
+func (s *Stream) sendWindowUpdate(consumed uint32) error {
+ var timer *time.Timer
+ var deadline <-chan time.Time
+ if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer = time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
+ var hdr updHeader
+ binary.LittleEndian.PutUint32(hdr[:], consumed)
+ binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
+ frame.data = hdr[:]
+ _, err := s.sess.writeFrameInternal(frame, deadline, 0)
+ return err
+}
+
+func (s *Stream) waitRead() error {
+ var timer *time.Timer
+ var deadline <-chan time.Time
+ if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer = time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ select {
+ case <-s.chReadEvent:
+ return nil
+ case <-s.chFinEvent:
+ // BUG(xtaci): Fix for https://github.com/xtaci/smux/issues/82
+ s.bufferLock.Lock()
+ defer s.bufferLock.Unlock()
+ if len(s.buffers) > 0 {
+ return nil
+ }
+ return io.EOF
+ case <-s.sess.chSocketReadError:
+ return s.sess.socketReadError.Load().(error)
+ case <-s.sess.chProtoError:
+ return s.sess.protoError.Load().(error)
+ case <-deadline:
+ return ErrTimeout
+ case <-s.die:
+ return io.ErrClosedPipe
+ }
+
+}
+
+// Write implements net.Conn
+//
+// Note that the behavior when multiple goroutines write concurrently is not deterministic,
+// frames may interleave in random way.
+func (s *Stream) Write(b []byte) (n int, err error) {
+ if s.sess.config.Version == 2 {
+ return s.writeV2(b)
+ }
+
+ var deadline <-chan time.Time
+ if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer := time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ // check if stream has closed
+ select {
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ default:
+ }
+
+ // frame split and transmit
+ sent := 0
+ frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
+ bts := b
+ for len(bts) > 0 {
+ sz := len(bts)
+ if sz > s.frameSize {
+ sz = s.frameSize
+ }
+ frame.data = bts[:sz]
+ bts = bts[sz:]
+ n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten))
+ s.numWritten++
+ sent += n
+ if err != nil {
+ return sent, err
+ }
+ }
+
+ return sent, nil
+}
+
+func (s *Stream) writeV2(b []byte) (n int, err error) {
+ // check empty input
+ if len(b) == 0 {
+ return 0, nil
+ }
+
+ // check if stream has closed
+ select {
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ default:
+ }
+
+ // create write deadline timer
+ var deadline <-chan time.Time
+ if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer := time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ // frame split and transmit process
+ sent := 0
+ frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
+
+ for {
+ // per stream sliding window control
+ // [.... [consumed... numWritten] ... win... ]
+ // [.... [consumed...................+rmtwnd]]
+ var bts []byte
+ // note:
+ // even if uint32 overflow, this math still works:
+ // eg1: uint32(0) - uint32(math.MaxUint32) = 1
+ // eg2: int32(uint32(0) - uint32(1)) = -1
+ // security check for misbehavior
+ inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
+ if inflight < 0 {
+ return 0, ErrConsumed
+ }
+
+ win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
+ if win > 0 {
+ if win > int32(len(b)) {
+ bts = b
+ b = nil
+ } else {
+ bts = b[:win]
+ b = b[win:]
+ }
+
+ for len(bts) > 0 {
+ sz := len(bts)
+ if sz > s.frameSize {
+ sz = s.frameSize
+ }
+ frame.data = bts[:sz]
+ bts = bts[sz:]
+ n, err := s.sess.writeFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
+ atomic.AddUint32(&s.numWritten, uint32(sz))
+ sent += n
+ if err != nil {
+ return sent, err
+ }
+ }
+ }
+
+ // if there is any data remaining to be sent
+ // wait until stream closes, window changes or deadline reached
+ // this blocking behavior will inform upper layer to do flow control
+ if len(b) > 0 {
+ select {
+ case <-s.chFinEvent: // if fin arrived, future window update is impossible
+ return 0, io.EOF
+ case <-s.die:
+ return sent, io.ErrClosedPipe
+ case <-deadline:
+ return sent, ErrTimeout
+ case <-s.sess.chSocketWriteError:
+ return sent, s.sess.socketWriteError.Load().(error)
+ case <-s.chUpdate:
+ continue
+ }
+ } else {
+ return sent, nil
+ }
+ }
+}
+
+// Close implements net.Conn
+func (s *Stream) Close() error {
+ var once bool
+ var err error
+ s.dieOnce.Do(func() {
+ close(s.die)
+ once = true
+ })
+
+ if once {
+ _, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id))
+ s.sess.streamClosed(s.id)
+ return err
+ } else {
+ return io.ErrClosedPipe
+ }
+}
+
+// GetDieCh returns a readonly chan which can be readable
+// when the stream is to be closed.
+func (s *Stream) GetDieCh() <-chan struct{} {
+ return s.die
+}
+
+// SetReadDeadline sets the read deadline as defined by
+// net.Conn.SetReadDeadline.
+// A zero time value disables the deadline.
+func (s *Stream) SetReadDeadline(t time.Time) error {
+ s.readDeadline.Store(t)
+ s.notifyReadEvent()
+ return nil
+}
+
+// SetWriteDeadline sets the write deadline as defined by
+// net.Conn.SetWriteDeadline.
+// A zero time value disables the deadline.
+func (s *Stream) SetWriteDeadline(t time.Time) error {
+ s.writeDeadline.Store(t)
+ return nil
+}
+
+// SetDeadline sets both read and write deadlines as defined by
+// net.Conn.SetDeadline.
+// A zero time value disables the deadlines.
+func (s *Stream) SetDeadline(t time.Time) error {
+ if err := s.SetReadDeadline(t); err != nil {
+ return err
+ }
+ if err := s.SetWriteDeadline(t); err != nil {
+ return err
+ }
+ return nil
+}
+
+// session closes
+func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }
+
+// LocalAddr satisfies net.Conn interface
+func (s *Stream) LocalAddr() net.Addr {
+ if ts, ok := s.sess.conn.(interface {
+ LocalAddr() net.Addr
+ }); ok {
+ return ts.LocalAddr()
+ }
+ return nil
+}
+
+// RemoteAddr satisfies net.Conn interface
+func (s *Stream) RemoteAddr() net.Addr {
+ if ts, ok := s.sess.conn.(interface {
+ RemoteAddr() net.Addr
+ }); ok {
+ return ts.RemoteAddr()
+ }
+ return nil
+}
+
+// pushBytes append buf to buffers
+func (s *Stream) pushBytes(buf []byte) (written int, err error) {
+ s.bufferLock.Lock()
+ s.buffers = append(s.buffers, buf)
+ s.heads = append(s.heads, buf)
+ s.bufferLock.Unlock()
+ return
+}
+
+// recycleTokens transform remaining bytes to tokens(will truncate buffer)
+func (s *Stream) recycleTokens() (n int) {
+ s.bufferLock.Lock()
+ for k := range s.buffers {
+ n += len(s.buffers[k])
+ defaultAllocator.Put(s.heads[k])
+ }
+ s.buffers = nil
+ s.heads = nil
+ s.bufferLock.Unlock()
+ return
+}
+
+// notify read event
+func (s *Stream) notifyReadEvent() {
+ select {
+ case s.chReadEvent <- struct{}{}:
+ default:
+ }
+}
+
+// update command
+func (s *Stream) update(consumed uint32, window uint32) {
+ atomic.StoreUint32(&s.peerConsumed, consumed)
+ atomic.StoreUint32(&s.peerWindow, window)
+ select {
+ case s.chUpdate <- struct{}{}:
+ default:
+ }
+}
+
+// mark this stream has been closed in protocol
+func (s *Stream) fin() {
+ s.finEventOnce.Do(func() {
+ close(s.chFinEvent)
+ })
+}