package smux import ( "encoding/binary" "io" "net" "sync" "sync/atomic" "time" ) // Stream implements net.Conn type Stream struct { id uint32 sess *Session buffers [][]byte heads [][]byte // slice heads kept for recycle bufferLock sync.Mutex frameSize int // notify a read event chReadEvent chan struct{} // flag the stream has closed die chan struct{} dieOnce sync.Once // FIN command chFinEvent chan struct{} finEventOnce sync.Once // deadlines readDeadline atomic.Value writeDeadline atomic.Value // per stream sliding window control numRead uint32 // number of consumed bytes numWritten uint32 // count num of bytes written incr uint32 // counting for sending // UPD command peerConsumed uint32 // num of bytes the peer has consumed peerWindow uint32 // peer window, initialized to 256KB, updated by peer chUpdate chan struct{} // notify of remote data consuming and window update } // newStream initiates a Stream struct func newStream(id uint32, frameSize int, sess *Session) *Stream { s := new(Stream) s.id = id s.chReadEvent = make(chan struct{}, 1) s.chUpdate = make(chan struct{}, 1) s.frameSize = frameSize s.sess = sess s.die = make(chan struct{}) s.chFinEvent = make(chan struct{}) s.peerWindow = initialPeerWindow // set to initial window size return s } // ID returns the unique stream ID. 
func (s *Stream) ID() uint32 {
	return s.id
}

// Read implements net.Conn.
//
// It returns as soon as tryRead produces data or a terminal error. When
// nothing is buffered yet, it parks in waitRead until a read event, a FIN,
// a session error, the read deadline or a local close occurs, then retries.
func (s *Stream) Read(b []byte) (n int, err error) {
	for {
		n, err = s.tryRead(b)
		if err != ErrWouldBlock {
			return n, err
		}
		if ew := s.waitRead(); ew != nil {
			return 0, ew
		}
	}
}

// tryRead is the nonblocking version of Read.
//
// An empty b yields (0, nil). Otherwise it copies from the head of the receive
// queue and returns the consumed bytes to the session's token pool. With
// nothing buffered it returns io.EOF if the stream has been closed, or
// ErrWouldBlock so the caller can wait.
func (s *Stream) tryRead(b []byte) (n int, err error) {
	if s.sess.config.Version == 2 {
		return s.tryReadv2(b)
	}
	if len(b) == 0 {
		return 0, nil
	}

	s.bufferLock.Lock()
	n = s.dequeueLocked(b)
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}

// tryReadv2 is tryRead for protocol version 2. In addition to copying data it
// tracks consumption for the per-stream sliding window and, when due, sends a
// window update to the peer. An error from that update is returned together
// with the (already consumed) byte count.
func (s *Stream) tryReadv2(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}

	s.bufferLock.Lock()
	n = s.dequeueLocked(b)
	// In an ideal environment, acknowledging once more than half of the
	// buffer has been consumed keeps data flowing continuously: the ACK
	// round trip overlaps with the reader draining the other half.
	notifyConsumed := s.markConsumedLocked(n)
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)
		if notifyConsumed > 0 {
			return n, s.sendWindowUpdate(notifyConsumed)
		}
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}

// dequeueLocked copies up to len(b) bytes from the head of the receive queue
// into b and returns the count. A chunk that becomes fully consumed is removed
// from the queue and its original allocation is returned to defaultAllocator.
// The caller must hold s.bufferLock.
func (s *Stream) dequeueLocked(b []byte) int {
	if len(s.buffers) == 0 {
		return 0
	}
	n := copy(b, s.buffers[0])
	s.buffers[0] = s.buffers[0][n:]
	if len(s.buffers[0]) == 0 {
		// Recycle the original allocation (not the advanced remainder) and
		// clear the vacated slots so the backing arrays stop referencing it.
		defaultAllocator.Put(s.heads[0])
		s.buffers[0], s.heads[0] = nil, nil
		s.buffers = s.buffers[1:]
		s.heads = s.heads[1:]
	}
	return n
}

// markConsumedLocked records n newly consumed bytes for the v2 sliding window.
// It returns the cumulative consumed count to advertise to the peer, or 0 when
// no update is due. An update is due once at least MaxStreamBuffer/2 bytes
// have been consumed since the previous one, or on the first read
// (numRead == n). The caller must hold s.bufferLock.
func (s *Stream) markConsumedLocked(n int) uint32 {
	s.numRead += uint32(n)
	s.incr += uint32(n)
	if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
		s.incr = 0
		return s.numRead
	}
	return 0
}

// takeBufferLocked removes the next chunk from the receive queue. It returns
// the unread remainder and the chunk's original allocation, or (nil, nil) when
// the queue is empty. The caller must hold s.bufferLock.
func (s *Stream) takeBufferLocked() (buf, head []byte) {
	if len(s.buffers) == 0 {
		return nil, nil
	}
	buf, head = s.buffers[0], s.heads[0]
	s.buffers[0], s.heads[0] = nil, nil
	s.buffers = s.buffers[1:]
	s.heads = s.heads[1:]
	return buf, head
}

// WriteTo implements io.WriterTo.
//
// Queued chunks are handed to w directly, without an intermediate copy. After
// each w.Write the chunk's tokens go back to the session and its original
// allocation is recycled. WriteTo runs until w.Write fails or waitRead reports
// that no more data can arrive. A FIN with an empty queue (io.EOF from
// waitRead) is the normal end of the stream and yields a nil error, as
// io.Copy's contract requires. Any other condition is returned unchanged.
func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
	if s.sess.config.Version == 2 {
		return s.writeTov2(w)
	}

	for {
		s.bufferLock.Lock()
		buf, head := s.takeBufferLocked()
		s.bufferLock.Unlock()

		if buf == nil {
			switch ew := s.waitRead(); ew {
			case nil:
				continue
			case io.EOF:
				return n, nil
			default:
				return n, ew
			}
		}

		nw, ew := w.Write(buf)
		s.sess.returnTokens(len(buf))
		defaultAllocator.Put(head)
		if nw > 0 {
			n += int64(nw)
		}
		if ew != nil {
			return n, ew
		}
	}
}

// writeTov2 is WriteTo for protocol version 2. It additionally tracks
// consumption for the sliding window and sends window updates to the peer
// once each chunk has been written to w.
func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
	for {
		s.bufferLock.Lock()
		buf, head := s.takeBufferLocked()
		notifyConsumed := s.markConsumedLocked(len(buf))
		s.bufferLock.Unlock()

		if buf == nil {
			switch ew := s.waitRead(); ew {
			case nil:
				continue
			case io.EOF:
				return n, nil
			default:
				return n, ew
			}
		}

		nw, ew := w.Write(buf)
		s.sess.returnTokens(len(buf))
		defaultAllocator.Put(head)
		if nw > 0 {
			n += int64(nw)
		}
		if ew != nil {
			return n, ew
		}

		if notifyConsumed > 0 {
			if eu := s.sendWindowUpdate(notifyConsumed); eu != nil {
				return n, eu
			}
		}
	}
}

// sendWindowUpdate sends a cmdUPD frame to the peer. The first 4 bytes of the
// payload carry the cumulative consumed count and the next 4 bytes carry this
// side's MaxStreamBuffer, both little-endian. It is issued from the read path,
// so the write is bounded by the read deadline rather than the write deadline.
func (s *Stream) sendWindowUpdate(consumed uint32) error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
	var hdr updHeader
	binary.LittleEndian.PutUint32(hdr[:], consumed)
	binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
	frame.data = hdr[:]
	_, err := s.sess.writeFrameInternal(frame, deadline, 0)
	return err
}

// waitRead blocks until the stream may have become readable. It returns:
//   - nil when a read event is signalled;
//   - nil on FIN if data is still buffered, io.EOF if the buffer is empty
//     (see https://github.com/xtaci/smux/issues/82);
//   - the session's socket-read or protocol error if one is reported;
//   - ErrTimeout when the read deadline (if any) expires;
//   - io.ErrClosedPipe when the stream has been closed.
func (s *Stream) waitRead() error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	select {
	case <-s.chReadEvent:
		return nil
	case <-s.chFinEvent:
		// BUG(xtaci): Fix for https://github.com/xtaci/smux/issues/82
		s.bufferLock.Lock()
		defer s.bufferLock.Unlock()
		if len(s.buffers) > 0 {
			return nil
		}
		return io.EOF
	case <-s.sess.chSocketReadError:
		return s.sess.socketReadError.Load().(error)
	case <-s.sess.chProtoError:
		return s.sess.protoError.Load().(error)
	case <-deadline:
		return ErrTimeout
	case <-s.die:
		return io.ErrClosedPipe
	}
}

// Write implements net.Conn
//
// Note that the behavior when multiple goroutines write concurrently is not deterministic,
// frames may interleave in random way.
//
// b is split into PSH frames of at most frameSize bytes. In protocol v1 each
// frame is tagged with a per-stream frame counter. The counter is advanced
// atomically, matching writeV2's use of the same field. The returned count is
// the sum reported by the session for the frames written so far.
func (s *Stream) Write(b []byte) (n int, err error) {
	if s.sess.config.Version == 2 {
		return s.writeV2(b)
	}

	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// check if stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// frame split and transmit
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
	for bts := b; len(bts) > 0; {
		sz := len(bts)
		if sz > s.frameSize {
			sz = s.frameSize
		}
		frame.data = bts[:sz]
		bts = bts[sz:]

		prio := uint64(atomic.AddUint32(&s.numWritten, 1) - 1)
		nw, ew := s.sess.writeFrameInternal(frame, deadline, prio)
		n += nw
		if ew != nil {
			return n, ew
		}
	}
	return n, nil
}

// writeV2 is Write for protocol version 2, honouring the peer's sliding window.
//
// It sends as much of b as the window allows and then blocks until a window
// update, FIN, local close, write deadline or session write error. On every
// return path, n reports the bytes already handed to the session, as io.Writer
// requires.
func (s *Stream) writeV2(b []byte) (n int, err error) {
	// check empty input
	if len(b) == 0 {
		return 0, nil
	}

	// check if stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// create write deadline timer
	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// frame split and transmit process
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
	for {
		// per stream sliding window control
		// [.... [consumed... numWritten] ... win... ]
		// [.... [consumed...................+rmtwnd]]
		//
		// The uint32 arithmetic stays correct across wraparound:
		//   uint32(0) - uint32(math.MaxUint32) == 1
		//   int32(uint32(0) - uint32(1)) == -1
		inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
		if inflight < 0 {
			// the peer claims to have consumed more than we sent: misbehaviour
			return n, ErrConsumed
		}

		if win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight; win > 0 {
			// Compare in int so writes larger than math.MaxInt32 cannot
			// overflow past the window check.
			var bts []byte
			if int(win) < len(b) {
				bts, b = b[:win], b[win:]
			} else {
				bts, b = b, nil
			}

			for len(bts) > 0 {
				sz := len(bts)
				if sz > s.frameSize {
					sz = s.frameSize
				}
				frame.data = bts[:sz]
				bts = bts[sz:]

				nw, ew := s.sess.writeFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
				atomic.AddUint32(&s.numWritten, uint32(sz))
				n += nw
				if ew != nil {
					return n, ew
				}
			}
		}

		if len(b) == 0 {
			return n, nil
		}

		// Data remains: block until the window may have moved. Blocking here
		// propagates flow control back to the caller.
		select {
		case <-s.chFinEvent:
			// once FIN has arrived no further window update is possible
			return n, io.EOF
		case <-s.die:
			return n, io.ErrClosedPipe
		case <-deadline:
			return n, ErrTimeout
		case <-s.sess.chSocketWriteError:
			return n, s.sess.socketWriteError.Load().(error)
		case <-s.chUpdate:
		}
	}
}

// Close implements net.Conn.
//
// The first call closes the die channel, sends a FIN frame, and tells the
// session the stream is gone. The error from writing FIN is returned. Every
// later call returns io.ErrClosedPipe.
func (s *Stream) Close() error {
	var first bool
	s.dieOnce.Do(func() {
		close(s.die)
		first = true
	})
	if !first {
		return io.ErrClosedPipe
	}

	_, err := s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id))
	s.sess.streamClosed(s.id)
	return err
}

// GetDieCh returns a readonly chan which can be readable
// when the stream is to be closed.
func (s *Stream) GetDieCh() <-chan struct{} {
	return s.die
}

// SetReadDeadline sets the read deadline as defined by
// net.Conn.SetReadDeadline.
// A zero time value disables the deadline.
func (s *Stream) SetReadDeadline(t time.Time) error { s.readDeadline.Store(t) s.notifyReadEvent() return nil } // SetWriteDeadline sets the write deadline as defined by // net.Conn.SetWriteDeadline. // A zero time value disables the deadline. func (s *Stream) SetWriteDeadline(t time.Time) error { s.writeDeadline.Store(t) return nil } // SetDeadline sets both read and write deadlines as defined by // net.Conn.SetDeadline. // A zero time value disables the deadlines. func (s *Stream) SetDeadline(t time.Time) error { if err := s.SetReadDeadline(t); err != nil { return err } if err := s.SetWriteDeadline(t); err != nil { return err } return nil } // session closes func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) } // LocalAddr satisfies net.Conn interface func (s *Stream) LocalAddr() net.Addr { if ts, ok := s.sess.conn.(interface { LocalAddr() net.Addr }); ok { return ts.LocalAddr() } return nil } // RemoteAddr satisfies net.Conn interface func (s *Stream) RemoteAddr() net.Addr { if ts, ok := s.sess.conn.(interface { RemoteAddr() net.Addr }); ok { return ts.RemoteAddr() } return nil } // pushBytes append buf to buffers func (s *Stream) pushBytes(buf []byte) (written int, err error) { s.bufferLock.Lock() s.buffers = append(s.buffers, buf) s.heads = append(s.heads, buf) s.bufferLock.Unlock() return } // recycleTokens transform remaining bytes to tokens(will truncate buffer) func (s *Stream) recycleTokens() (n int) { s.bufferLock.Lock() for k := range s.buffers { n += len(s.buffers[k]) defaultAllocator.Put(s.heads[k]) } s.buffers = nil s.heads = nil s.bufferLock.Unlock() return } // notify read event func (s *Stream) notifyReadEvent() { select { case s.chReadEvent <- struct{}{}: default: } } // update command func (s *Stream) update(consumed uint32, window uint32) { atomic.StoreUint32(&s.peerConsumed, consumed) atomic.StoreUint32(&s.peerWindow, window) select { case s.chUpdate <- struct{}{}: default: } } // mark this stream has been closed in 
protocol func (s *Stream) fin() { s.finEventOnce.Do(func() { close(s.chFinEvent) }) }