package smux import ( "container/heap" "encoding/binary" "errors" "io" "net" "sync" "sync/atomic" "time" ) const ( defaultAcceptBacklog = 1024 ) var ( ErrInvalidProtocol = errors.New("invalid protocol") ErrConsumed = errors.New("peer consumed more than sent") ErrGoAway = errors.New("stream id overflows, should start a new connection") ErrTimeout = errors.New("timeout") ErrWouldBlock = errors.New("operation would block on IO") ) type writeRequest struct { prio uint64 frame Frame result chan writeResult } type writeResult struct { n int err error } type buffersWriter interface { WriteBuffers(v [][]byte) (n int, err error) } // Session defines a multiplexed connection for streams type Session struct { conn io.ReadWriteCloser config *Config nextStreamID uint32 // next stream identifier nextStreamIDLock sync.Mutex bucket int32 // token bucket bucketNotify chan struct{} // used for waiting for tokens streams map[uint32]*Stream // all streams in this session streamLock sync.Mutex // locks streams die chan struct{} // flag session has died dieOnce sync.Once // socket error handling socketReadError atomic.Value socketWriteError atomic.Value chSocketReadError chan struct{} chSocketWriteError chan struct{} socketReadErrorOnce sync.Once socketWriteErrorOnce sync.Once // smux protocol errors protoError atomic.Value chProtoError chan struct{} protoErrorOnce sync.Once chAccepts chan *Stream dataReady int32 // flag data has arrived goAway int32 // flag id exhausted deadline atomic.Value shaper chan writeRequest // a shaper for writing writes chan writeRequest } func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { s := new(Session) s.die = make(chan struct{}) s.conn = conn s.config = config s.streams = make(map[uint32]*Stream) s.chAccepts = make(chan *Stream, defaultAcceptBacklog) s.bucket = int32(config.MaxReceiveBuffer) s.bucketNotify = make(chan struct{}, 1) s.shaper = make(chan writeRequest) s.writes = make(chan writeRequest) s.chSocketReadError = make(chan struct{}) s.chSocketWriteError = make(chan struct{}) s.chProtoError = make(chan struct{}) if client { s.nextStreamID = 1 } else { s.nextStreamID = 0 } go s.shaperLoop() go s.recvLoop() go s.sendLoop() if !config.KeepAliveDisabled { go s.keepalive() } return s } // OpenStream is used to create a new stream func (s *Session) OpenStream() (*Stream, error) { if s.IsClosed() { return nil, io.ErrClosedPipe } // generate stream id s.nextStreamIDLock.Lock() if s.goAway > 0 { s.nextStreamIDLock.Unlock() return nil, ErrGoAway } s.nextStreamID += 2 sid := s.nextStreamID if sid == sid%2 { // stream-id overflows s.goAway = 1 s.nextStreamIDLock.Unlock() return nil, ErrGoAway } s.nextStreamIDLock.Unlock() stream := newStream(sid, s.config.MaxFrameSize, s) if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { return nil, err } s.streamLock.Lock() defer s.streamLock.Unlock() select { case <-s.chSocketReadError: return nil, s.socketReadError.Load().(error) case <-s.chSocketWriteError: return nil, s.socketWriteError.Load().(error) case <-s.die: return nil, io.ErrClosedPipe default: s.streams[sid] = stream return stream, nil } } // Open returns a generic ReadWriteCloser func (s *Session) Open() (io.ReadWriteCloser, error) { return s.OpenStream() } // AcceptStream is used to block until the next available stream // is ready to be accepted. func (s *Session) AcceptStream() (*Stream, error) { var deadline <-chan time.Time if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(time.Until(d)) defer timer.Stop() deadline = timer.C } select { case stream := <-s.chAccepts: return stream, nil case <-deadline: return nil, ErrTimeout case <-s.chSocketReadError: return nil, s.socketReadError.Load().(error) case <-s.chProtoError: return nil, s.protoError.Load().(error) case <-s.die: return nil, io.ErrClosedPipe } } // Accept Returns a generic ReadWriteCloser instead of smux.Stream func (s *Session) Accept() (io.ReadWriteCloser, error) { return s.AcceptStream() } // Close is used to close the session and all streams. func (s *Session) Close() error { var once bool s.dieOnce.Do(func() { close(s.die) once = true }) if once { s.streamLock.Lock() for k := range s.streams { s.streams[k].sessionClose() } s.streamLock.Unlock() return s.conn.Close() } else { return io.ErrClosedPipe } } // notifyBucket notifies recvLoop that bucket is available func (s *Session) notifyBucket() { select { case s.bucketNotify <- struct{}{}: default: } } func (s *Session) notifyReadError(err error) { s.socketReadErrorOnce.Do(func() { s.socketReadError.Store(err) close(s.chSocketReadError) }) } func (s *Session) notifyWriteError(err error) { s.socketWriteErrorOnce.Do(func() { s.socketWriteError.Store(err) close(s.chSocketWriteError) }) } func (s *Session) notifyProtoError(err error) { s.protoErrorOnce.Do(func() { s.protoError.Store(err) close(s.chProtoError) }) } // IsClosed does a safe check to see if we have shutdown func (s *Session) IsClosed() bool { select { case <-s.die: return true default: return false } } // NumStreams returns the number of currently open streams func (s *Session) NumStreams() int { if s.IsClosed() { return 0 } s.streamLock.Lock() defer s.streamLock.Unlock() return len(s.streams) } // SetDeadline sets a deadline used by Accept* calls. // A zero time value disables the deadline. func (s *Session) SetDeadline(t time.Time) error { s.deadline.Store(t) return nil } // LocalAddr satisfies net.Conn interface func (s *Session) LocalAddr() net.Addr { if ts, ok := s.conn.(interface { LocalAddr() net.Addr }); ok { return ts.LocalAddr() } return nil } // RemoteAddr satisfies net.Conn interface func (s *Session) RemoteAddr() net.Addr { if ts, ok := s.conn.(interface { RemoteAddr() net.Addr }); ok { return ts.RemoteAddr() } return nil } // notify the session that a stream has closed func (s *Session) streamClosed(sid uint32) { s.streamLock.Lock() if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket if atomic.AddInt32(&s.bucket, int32(n)) > 0 { s.notifyBucket() } } delete(s.streams, sid) s.streamLock.Unlock() } // returnTokens is called by stream to return token after read func (s *Session) returnTokens(n int) { if atomic.AddInt32(&s.bucket, int32(n)) > 0 { s.notifyBucket() } } // recvLoop keeps on reading from underlying connection if tokens are available func (s *Session) recvLoop() { var hdr rawHeader var updHdr updHeader for { for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { select { case <-s.bucketNotify: case <-s.die: return } } // read header first if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { atomic.StoreInt32(&s.dataReady, 1) if hdr.Version() != byte(s.config.Version) { s.notifyProtoError(ErrInvalidProtocol) return } sid := hdr.StreamID() switch hdr.Cmd() { case cmdNOP: case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[sid]; !ok { stream := newStream(sid, s.config.MaxFrameSize, s) s.streams[sid] = stream select { case s.chAccepts <- stream: case <-s.die: } } s.streamLock.Unlock() case cmdFIN: s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { stream.fin() stream.notifyReadEvent() } s.streamLock.Unlock() case cmdPSH: if hdr.Length() > 0 { newbuf := defaultAllocator.Get(int(hdr.Length())) if written, err := io.ReadFull(s.conn, newbuf); err == nil { s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { stream.pushBytes(newbuf) atomic.AddInt32(&s.bucket, -int32(written)) stream.notifyReadEvent() } s.streamLock.Unlock() } else { s.notifyReadError(err) return } } case cmdUPD: if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil { s.streamLock.Lock() if stream, ok := s.streams[sid]; ok { stream.update(updHdr.Consumed(), updHdr.Window()) } s.streamLock.Unlock() } else { s.notifyReadError(err) return } default: s.notifyProtoError(ErrInvalidProtocol) return } } else { s.notifyReadError(err) return } } } func (s *Session) keepalive() { tickerPing := time.NewTicker(s.config.KeepAliveInterval) tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) defer tickerPing.Stop() defer tickerTimeout.Stop() for { select { case <-tickerPing.C: s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0) s.notifyBucket() // force a signal to the recvLoop case <-tickerTimeout.C: if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { // recvLoop may block while bucket is 0, in this case, // session should not be closed. if atomic.LoadInt32(&s.bucket) > 0 { s.Close() return } } case <-s.die: return } } } // shaper shapes the sending sequence among streams func (s *Session) shaperLoop() { var reqs shaperHeap var next writeRequest var chWrite chan writeRequest for { if len(reqs) > 0 { chWrite = s.writes next = heap.Pop(&reqs).(writeRequest) } else { chWrite = nil } select { case <-s.die: return case r := <-s.shaper: if chWrite != nil { // next is valid, reshape heap.Push(&reqs, next) } heap.Push(&reqs, r) case chWrite <- next: } } } func (s *Session) sendLoop() { var buf []byte var n int var err error var vec [][]byte // vector for writeBuffers bw, ok := s.conn.(buffersWriter) if ok { buf = make([]byte, headerSize) vec = make([][]byte, 2) } else { buf = make([]byte, (1<<16)+headerSize) } for { select { case <-s.die: return case request := <-s.writes: buf[0] = request.frame.ver buf[1] = request.frame.cmd binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) if len(vec) > 0 { vec[0] = buf[:headerSize] vec[1] = request.frame.data n, err = bw.WriteBuffers(vec) } else { copy(buf[headerSize:], request.frame.data) n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) } n -= headerSize if n < 0 { n = 0 } result := writeResult{ n: n, err: err, } request.result <- result close(request.result) // store conn error if err != nil { s.notifyWriteError(err) return } } } } // writeFrame writes the frame to the underlying connection // and returns the number of bytes written if successful func (s *Session) writeFrame(f Frame) (n int, err error) { return s.writeFrameInternal(f, nil, 0) } // internal writeFrame version to support deadline used in keepalive func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) { req := writeRequest{ prio: prio, frame: f, result: make(chan writeResult, 1), } select { case s.shaper <- req: case <-s.die: return 0, io.ErrClosedPipe case <-s.chSocketWriteError: return 0, s.socketWriteError.Load().(error) case <-deadline: return 0, ErrTimeout } select { case result := <-req.result: return result.n, result.err case <-s.die: return 0, io.ErrClosedPipe case <-s.chSocketWriteError: return 0, s.socketWriteError.Load().(error) case <-deadline: return 0, ErrTimeout } }