diff options
Diffstat (limited to 'vendor/github.com/pion/sctp/association.go')
-rw-r--r-- | vendor/github.com/pion/sctp/association.go | 2241 |
1 files changed, 2241 insertions, 0 deletions
diff --git a/vendor/github.com/pion/sctp/association.go b/vendor/github.com/pion/sctp/association.go new file mode 100644 index 0000000..1393cb8 --- /dev/null +++ b/vendor/github.com/pion/sctp/association.go @@ -0,0 +1,2241 @@ +package sctp + +import ( + "bytes" + "fmt" + "io" + "math" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/randutil" + "github.com/pkg/errors" +) + +// Use global random generator to properly seed by crypto grade random. +var ( + globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals + errChunk = errors.New("Abort chunk, with following errors") +) + +const ( + receiveMTU uint32 = 8192 // MTU for inbound packet (from DTLS) + initialMTU uint32 = 1228 // initial MTU for outgoing packets (to DTLS) + initialRecvBufSize uint32 = 1024 * 1024 + commonHeaderSize uint32 = 12 + dataChunkHeaderSize uint32 = 16 + defaultMaxMessageSize uint32 = 65536 +) + +// association state enums +const ( + closed uint32 = iota + cookieWait + cookieEchoed + established + shutdownAckSent + shutdownPending + shutdownReceived + shutdownSent +) + +// retransmission timer IDs +const ( + timerT1Init int = iota + timerT1Cookie + timerT3RTX + timerReconfig +) + +// ack mode (for testing) +const ( + ackModeNormal int = iota + ackModeNoDelay + ackModeAlwaysDelay +) + +// ack transmission state +const ( + ackStateIdle int = iota // ack timer is off + ackStateImmediate // ack timer is on (ack is being delayed) + ackStateDelay // will send ack immediately +) + +// other constants +const ( + acceptChSize = 16 +) + +func getAssociationStateString(a uint32) string { + switch a { + case closed: + return "Closed" + case cookieWait: + return "CookieWait" + case cookieEchoed: + return "CookieEchoed" + case established: + return "Established" + case shutdownPending: + return "ShutdownPending" + case shutdownSent: + return "ShutdownSent" + case shutdownReceived: + return "ShutdownReceived" + case shutdownAckSent: + return "ShutdownAckSent" + default: + return fmt.Sprintf("Invalid association state %d", a) + } +} + +// Association represents an SCTP association +// 13.2. Parameters Necessary per Association (i.e., the TCB) +// Peer : Tag value to be sent in every packet and is received +// Verification: in the INIT or INIT ACK chunk. +// Tag : +// +// My : Tag expected in every inbound packet and sent in the +// Verification: INIT or INIT ACK chunk. +// +// Tag : +// State : A state variable indicating what state the association +// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED, +// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, +// : SHUTDOWN-ACK-SENT. +// +// Note: No "CLOSED" state is illustrated since if a +// association is "CLOSED" its TCB SHOULD be removed. +type Association struct { + bytesReceived uint64 + bytesSent uint64 + + lock sync.RWMutex + + netConn net.Conn + + peerVerificationTag uint32 + myVerificationTag uint32 + state uint32 + myNextTSN uint32 // nextTSN + peerLastTSN uint32 // lastRcvdTSN + minTSN2MeasureRTT uint32 // for RTT measurement + willSendForwardTSN bool + willRetransmitFast bool + willRetransmitReconfig bool + + // Reconfig + myNextRSN uint32 + reconfigs map[uint32]*chunkReconfig + reconfigRequests map[uint32]*paramOutgoingResetRequest + + // Non-RFC internal data + sourcePort uint16 + destinationPort uint16 + myMaxNumInboundStreams uint16 + myMaxNumOutboundStreams uint16 + myCookie *paramStateCookie + payloadQueue *payloadQueue + inflightQueue *payloadQueue + pendingQueue *pendingQueue + controlQueue *controlQueue + mtu uint32 + maxPayloadSize uint32 // max DATA chunk payload size + cumulativeTSNAckPoint uint32 + advancedPeerTSNAckPoint uint32 + useForwardTSN bool + + // Congestion control parameters + maxReceiveBufferSize uint32 + maxMessageSize uint32 + cwnd uint32 // my congestion window size + rwnd uint32 // calculated peer's receiver windows size + ssthresh uint32 // slow start threshold + partialBytesAcked uint32 + inFastRecovery bool + fastRecoverExitPoint uint32 + + // RTX & Ack timer + rtoMgr *rtoManager + t1Init *rtxTimer + t1Cookie *rtxTimer + t3RTX *rtxTimer + tReconfig *rtxTimer + ackTimer *ackTimer + + // Chunks stored for retransmission + storedInit *chunkInit + storedCookieEcho *chunkCookieEcho + + streams map[uint16]*Stream + acceptCh chan *Stream + readLoopCloseCh chan struct{} + awakeWriteLoopCh chan struct{} + closeWriteLoopCh chan struct{} + handshakeCompletedCh chan error + + closeWriteLoopOnce sync.Once + + // local error + silentError error + + ackState int + ackMode int // for testing + + // stats + stats *associationStats + + // per inbound packet context + delayedAckTriggered bool + immediateAckTriggered bool + + name string + log logging.LeveledLogger +} + +// Config collects the arguments to createAssociation construction into +// a single structure +type Config struct { + NetConn net.Conn + MaxReceiveBufferSize uint32 + MaxMessageSize uint32 + LoggerFactory logging.LoggerFactory +} + +// Server accepts a SCTP stream over a conn +func Server(config Config) (*Association, error) { + a := createAssociation(config) + a.init(false) + + select { + case err := <-a.handshakeCompletedCh: + if err != nil { + return nil, err + } + return a, nil + case <-a.readLoopCloseCh: + return nil, errors.Errorf("association closed before connecting") + } +} + +// Client opens a SCTP stream over a conn +func Client(config Config) (*Association, error) { + a := createAssociation(config) + a.init(true) + + select { + case err := <-a.handshakeCompletedCh: + if err != nil { + return nil, err + } + return a, nil + case <-a.readLoopCloseCh: + return nil, errors.Errorf("association closed before connecting") + } +} + +func createAssociation(config Config) *Association { + var maxReceiveBufferSize uint32 + if config.MaxReceiveBufferSize == 0 { + maxReceiveBufferSize = initialRecvBufSize + } else { + maxReceiveBufferSize = config.MaxReceiveBufferSize + } + + var maxMessageSize uint32 + if config.MaxMessageSize == 0 { + maxMessageSize = defaultMaxMessageSize + } else { + maxMessageSize = config.MaxMessageSize + } + + tsn := globalMathRandomGenerator.Uint32() + a := &Association{ + netConn: config.NetConn, + maxReceiveBufferSize: maxReceiveBufferSize, + maxMessageSize: maxMessageSize, + myMaxNumOutboundStreams: math.MaxUint16, + myMaxNumInboundStreams: math.MaxUint16, + payloadQueue: newPayloadQueue(), + inflightQueue: newPayloadQueue(), + pendingQueue: newPendingQueue(), + controlQueue: newControlQueue(), + mtu: initialMTU, + maxPayloadSize: initialMTU - (commonHeaderSize + dataChunkHeaderSize), + myVerificationTag: globalMathRandomGenerator.Uint32(), + myNextTSN: tsn, + myNextRSN: tsn, + minTSN2MeasureRTT: tsn, + state: closed, + rtoMgr: newRTOManager(), + streams: map[uint16]*Stream{}, + reconfigs: map[uint32]*chunkReconfig{}, + reconfigRequests: map[uint32]*paramOutgoingResetRequest{}, + acceptCh: make(chan *Stream, acceptChSize), + readLoopCloseCh: make(chan struct{}), + awakeWriteLoopCh: make(chan struct{}, 1), + closeWriteLoopCh: make(chan struct{}), + handshakeCompletedCh: make(chan error), + cumulativeTSNAckPoint: tsn - 1, + advancedPeerTSNAckPoint: tsn - 1, + silentError: errors.Errorf("silently discard"), + stats: &associationStats{}, + log: config.LoggerFactory.NewLogger("sctp"), + } + + a.name = fmt.Sprintf("%p", a) + + // RFC 4690 Sec 7.2.1 + // o The initial cwnd before DATA transmission or after a sufficiently + // long idle period MUST be set to min(4*MTU, max (2*MTU, 4380 + // bytes)). + a.cwnd = min32(4*a.mtu, max32(2*a.mtu, 4380)) + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", + a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + + a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans) + a.t1Cookie = newRTXTimer(timerT1Cookie, a, maxInitRetrans) + a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans) // retransmit forever + a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans) // retransmit forever + a.ackTimer = newAckTimer(a) + + return a +} + +func (a *Association) init(isClient bool) { + a.lock.Lock() + defer a.lock.Unlock() + + go a.readLoop() + go a.writeLoop() + + if isClient { + a.setState(cookieWait) + init := &chunkInit{} + init.initialTSN = a.myNextTSN + init.numOutboundStreams = a.myMaxNumOutboundStreams + init.numInboundStreams = a.myMaxNumInboundStreams + init.initiateTag = a.myVerificationTag + init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize + setSupportedExtensions(&init.chunkInitCommon) + a.storedInit = init + + err := a.sendInit() + if err != nil { + a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) + } + + a.t1Init.start(a.rtoMgr.getRTO()) + } +} + +// caller must hold a.lock +func (a *Association) sendInit() error { + a.log.Debugf("[%s] sending INIT", a.name) + if a.storedInit == nil { + return errors.Errorf("the init not stored to send") + } + + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + a.sourcePort = 5000 // Spec?? + a.destinationPort = 5000 // Spec?? + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + + outbound.chunks = []chunk{a.storedInit} + + a.controlQueue.push(outbound) + a.awakeWriteLoop() + + return nil +} + +// caller must hold a.lock +func (a *Association) sendCookieEcho() error { + if a.storedCookieEcho == nil { + return errors.Errorf("cookieEcho not stored to send") + } + + a.log.Debugf("[%s] sending COOKIE-ECHO", a.name) + + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + outbound.chunks = []chunk{a.storedCookieEcho} + + a.controlQueue.push(outbound) + a.awakeWriteLoop() + + return nil +} + +// Close ends the SCTP Association and cleans up any state +func (a *Association) Close() error { + a.log.Debugf("[%s] closing association..", a.name) + + a.setState(closed) + + err := a.netConn.Close() + + a.closeAllTimers() + + // awake writeLoop to exit + a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) + + // Wait for readLoop to end + <-a.readLoopCloseCh + + a.log.Debugf("[%s] association closed", a.name) + a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) + a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs()) + a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) + a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) + a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) + return err +} + +func (a *Association) closeAllTimers() { + // Close all retransmission & ack timers + a.t1Init.close() + a.t1Cookie.close() + a.t3RTX.close() + a.tReconfig.close() + a.ackTimer.close() +} + +func (a *Association) readLoop() { + var closeErr error + defer func() { + // also stop writeLoop, otherwise writeLoop can be leaked + // if connection is lost when there is no writing packet. + a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) + + a.lock.Lock() + for _, s := range a.streams { + a.unregisterStream(s, closeErr) + } + a.lock.Unlock() + close(a.acceptCh) + close(a.readLoopCloseCh) + }() + + a.log.Debugf("[%s] readLoop entered", a.name) + buffer := make([]byte, receiveMTU) + + for { + n, err := a.netConn.Read(buffer) + if err != nil { + closeErr = err + break + } + // Make a buffer sized to what we read, then copy the data we + // read from the underlying transport. We do this because the + // user data is passed to the reassembly queue without + // copying. + inbound := make([]byte, n) + copy(inbound, buffer[:n]) + atomic.AddUint64(&a.bytesReceived, uint64(n)) + if err = a.handleInbound(inbound); err != nil { + closeErr = err + break + } + } + + a.log.Debugf("[%s] readLoop exited %s", a.name, closeErr) +} + +func (a *Association) writeLoop() { + a.log.Debugf("[%s] writeLoop entered", a.name) + +loop: + for { + rawPackets := a.gatherOutbound() + + for _, raw := range rawPackets { + _, err := a.netConn.Write(raw) + if err != nil { + if err != io.EOF { + a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err) + } + a.log.Debugf("[%s] writeLoop ended", a.name) + break loop + } + atomic.AddUint64(&a.bytesSent, uint64(len(raw))) + } + + select { + case <-a.awakeWriteLoopCh: + case <-a.closeWriteLoopCh: + break loop + } + } + + a.setState(closed) + a.closeAllTimers() + + a.log.Debugf("[%s] writeLoop exited", a.name) +} + +func (a *Association) awakeWriteLoop() { + select { + case a.awakeWriteLoopCh <- struct{}{}: + default: + } +} + +// unregisterStream un-registers a stream from the association +// The caller should hold the association write lock. +func (a *Association) unregisterStream(s *Stream, err error) { + s.lock.Lock() + defer s.lock.Unlock() + + delete(a.streams, s.streamIdentifier) + s.readErr = err + s.readNotifier.Broadcast() +} + +// handleInbound parses incoming raw packets +func (a *Association) handleInbound(raw []byte) error { + p := &packet{} + if err := p.unmarshal(raw); err != nil { + a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err) + return nil + } + + if err := checkPacket(p); err != nil { + a.log.Warnf("[%s] failed validating packet %s", a.name, err) + return nil + } + + a.handleChunkStart() + + for _, c := range p.chunks { + if err := a.handleChunk(p, c); err != nil { + return err + } + } + + a.handleChunkEnd() + + return nil +} + +// The caller should hold the lock +func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) [][]byte { + for _, p := range a.getDataPacketsToRetransmit() { + raw, err := p.marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) + continue + } + rawPackets = append(rawPackets, raw) + } + + // Pop unsent data chunks from the pending queue to send as much as + // cwnd and rwnd allow. + chunks, sisToReset := a.popPendingDataChunksToSend() + if len(chunks) > 0 { + // Start timer. (noop if already started) + a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name) + a.t3RTX.start(a.rtoMgr.getRTO()) + for _, p := range a.bundleDataChunksIntoPackets(chunks) { + raw, err := p.marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet", a.name) + continue + } + rawPackets = append(rawPackets, raw) + } + } + + if len(sisToReset) > 0 || a.willRetransmitReconfig { + if a.willRetransmitReconfig { + a.willRetransmitReconfig = false + a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs)) + for _, c := range a.reconfigs { + p := a.createPacket([]chunk{c}) + raw, err := p.marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + } + + if len(sisToReset) > 0 { + rsn := a.generateNextRSN() + tsn := a.myNextTSN - 1 + c := &chunkReconfig{ + paramA: ¶mOutgoingResetRequest{ + reconfigRequestSequenceNumber: rsn, + senderLastTSN: tsn, + streamIdentifiers: sisToReset, + }, + } + a.reconfigs[rsn] = c // store in the map for retransmission + a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v", + a.name, rsn, a.myNextTSN-1, sisToReset) + p := a.createPacket([]chunk{c}) + raw, err := p.marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + + if len(a.reconfigs) > 0 { + a.tReconfig.start(a.rtoMgr.getRTO()) + } + } + + return rawPackets +} + +// The caller should hold the lock +func (a *Association) gatherOutboundFrastRetransmissionPackets(rawPackets [][]byte) [][]byte { + if a.willRetransmitFast { + a.willRetransmitFast = false + + toFastRetrans := []chunk{} + fastRetransSize := commonHeaderSize + + for i := 0; ; i++ { + c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) + if !ok { + break // end of pending data + } + + if c.acked || c.abandoned() { + continue + } + + if c.nSent > 1 || c.missIndicator < 3 { + continue + } + + // RFC 4960 Sec 7.2.4 Fast Retransmit on Gap Reports + // 3) Determine how many of the earliest (i.e., lowest TSN) DATA chunks + // marked for retransmission will fit into a single packet, subject + // to constraint of the path MTU of the destination transport + // address to which the packet is being sent. Call this value K. + // Retransmit those K DATA chunks in a single packet. When a Fast + // Retransmit is being performed, the sender SHOULD ignore the value + // of cwnd and SHOULD NOT delay retransmission for this single + // packet. + + dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData)) + if a.mtu < fastRetransSize+dataChunkSize { + break + } + + fastRetransSize += dataChunkSize + a.stats.incFastRetrans() + c.nSent++ + a.checkPartialReliabilityStatus(c) + toFastRetrans = append(toFastRetrans, c) + a.log.Tracef("[%s] fast-retransmit: tsn=%d sent=%d htna=%d", + a.name, c.tsn, c.nSent, a.fastRecoverExitPoint) + } + + if len(toFastRetrans) > 0 { + raw, err := a.createPacket(toFastRetrans).marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + } + + return rawPackets +} + +// The caller should hold the lock +func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { + if a.ackState == ackStateImmediate { + a.ackState = ackStateIdle + sack := a.createSelectiveAckChunk() + a.log.Debugf("[%s] sending SACK: %s", a.name, sack.String()) + raw, err := a.createPacket([]chunk{sack}).marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a SACK packet", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + + return rawPackets +} + +// The caller should hold the lock +func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]byte { + if a.willSendForwardTSN { + a.willSendForwardTSN = false + if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + fwdtsn := a.createForwardTSN() + raw, err := a.createPacket([]chunk{fwdtsn}).marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + } + + return rawPackets +} + +// gatherOutbound gathers outgoing packets +func (a *Association) gatherOutbound() [][]byte { + a.lock.Lock() + defer a.lock.Unlock() + + rawPackets := [][]byte{} + + if a.controlQueue.size() > 0 { + for _, p := range a.controlQueue.popAll() { + raw, err := p.marshal() + if err != nil { + a.log.Warnf("[%s] failed to serialize a control packet", a.name) + continue + } + rawPackets = append(rawPackets, raw) + } + } + + state := a.getState() + + if state == established { + rawPackets = a.gatherOutboundDataAndReconfigPackets(rawPackets) + rawPackets = a.gatherOutboundFrastRetransmissionPackets(rawPackets) + rawPackets = a.gatherOutboundSackPackets(rawPackets) + rawPackets = a.gatherOutboundForwardTSNPackets(rawPackets) + } + + return rawPackets +} + +func checkPacket(p *packet) error { + // All packets must adhere to these rules + + // This is the SCTP sender's port number. It can be used by the + // receiver in combination with the source IP address, the SCTP + // destination port, and possibly the destination IP address to + // identify the association to which this packet belongs. The port + // number 0 MUST NOT be used. + if p.sourcePort == 0 { + return errors.Errorf("sctp packet must not have a source port of 0") + } + + // This is the SCTP port number to which this packet is destined. + // The receiving host will use this port number to de-multiplex the + // SCTP packet to the correct receiving endpoint/application. The + // port number 0 MUST NOT be used. + if p.destinationPort == 0 { + return errors.Errorf("sctp packet must not have a destination port of 0") + } + + // Check values on the packet that are specific to a particular chunk type + for _, c := range p.chunks { + switch c.(type) { // nolint:gocritic + case *chunkInit: + // An INIT or INIT ACK chunk MUST NOT be bundled with any other chunk. + // They MUST be the only chunks present in the SCTP packets that carry + // them. + if len(p.chunks) != 1 { + return errors.Errorf("init chunk must not be bundled with any other chunk") + } + + // A packet containing an INIT chunk MUST have a zero Verification + // Tag. + if p.verificationTag != 0 { + return errors.Errorf("init chunk expects a verification tag of 0 on the packet when out-of-the-blue") + } + } + } + + return nil +} + +func min16(a, b uint16) uint16 { + if a < b { + return a + } + return b +} + +func max32(a, b uint32) uint32 { + if a > b { + return a + } + return b +} + +func min32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + +// setState atomically sets the state of the Association. +// The caller should hold the lock. +func (a *Association) setState(newState uint32) { + oldState := atomic.SwapUint32(&a.state, newState) + if newState != oldState { + a.log.Debugf("[%s] state change: '%s' => '%s'", + a.name, + getAssociationStateString(oldState), + getAssociationStateString(newState)) + } +} + +// getState atomically returns the state of the Association. +func (a *Association) getState() uint32 { + return atomic.LoadUint32(&a.state) +} + +// BytesSent returns the number of bytes sent +func (a *Association) BytesSent() uint64 { + return atomic.LoadUint64(&a.bytesSent) +} + +// BytesReceived returns the number of bytes received +func (a *Association) BytesReceived() uint64 { + return atomic.LoadUint64(&a.bytesReceived) +} + +func setSupportedExtensions(init *chunkInitCommon) { + // nolint:godox + // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 + // An implementation supporting this (Supported Extensions Parameter) + // extension MUST list the ASCONF, the ASCONF-ACK, and the AUTH chunks + // in its INIT and INIT-ACK parameters. + init.params = append(init.params, ¶mSupportedExtensions{ + ChunkTypes: []chunkType{ctReconfig, ctForwardTSN}, + }) +} + +// The caller should hold the lock. +func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { + state := a.getState() + a.log.Debugf("[%s] chunkInit received in state '%s'", a.name, getAssociationStateString(state)) + + // https://tools.ietf.org/html/rfc4960#section-5.2.1 + // Upon receipt of an INIT in the COOKIE-WAIT state, an endpoint MUST + // respond with an INIT ACK using the same parameters it sent in its + // original INIT chunk (including its Initiate Tag, unchanged). When + // responding, the endpoint MUST send the INIT ACK back to the same + // address that the original INIT (sent by this endpoint) was sent. + + if state != closed && state != cookieWait && state != cookieEchoed { + // 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, + // COOKIE-WAIT, and SHUTDOWN-ACK-SENT + return nil, errors.Errorf("todo: handle Init when in state %s", getAssociationStateString(state)) + } + + // Should we be setting any of these permanently until we've ACKed further? + a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) + a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) + a.peerVerificationTag = i.initiateTag + a.sourcePort = p.destinationPort + a.destinationPort = p.sourcePort + + // 13.2 This is the last TSN received in sequence. This value + // is set initially by taking the peer's initial TSN, + // received in the INIT or INIT ACK chunk, and + // subtracting one from it. + a.peerLastTSN = i.initialTSN - 1 + + for _, param := range i.params { + switch v := param.(type) { // nolint:gocritic + case *paramSupportedExtensions: + for _, t := range v.ChunkTypes { + if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on init)\n", a.name) + a.useForwardTSN = true + } + } + } + } + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on init)\n", a.name) + } + + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + + initAck := &chunkInitAck{} + + initAck.initialTSN = a.myNextTSN + initAck.numOutboundStreams = a.myMaxNumOutboundStreams + initAck.numInboundStreams = a.myMaxNumInboundStreams + initAck.initiateTag = a.myVerificationTag + initAck.advertisedReceiverWindowCredit = a.maxReceiveBufferSize + + if a.myCookie == nil { + var err error + if a.myCookie, err = newRandomStateCookie(); err != nil { + return nil, err + } + } + + initAck.params = []param{a.myCookie} + + setSupportedExtensions(&initAck.chunkInitCommon) + + outbound.chunks = []chunk{initAck} + + return pack(outbound), nil +} + +// The caller should hold the lock. +func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { + state := a.getState() + a.log.Debugf("[%s] chunkInitAck received in state '%s'", a.name, getAssociationStateString(state)) + if state != cookieWait { + // RFC 4960 + // 5.2.3. Unexpected INIT ACK + // If an INIT ACK is received by an endpoint in any state other than the + // COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk. + // An unexpected INIT ACK usually indicates the processing of an old or + // duplicated INIT chunk. + return nil + } + + a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) + a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) + a.peerVerificationTag = i.initiateTag + a.peerLastTSN = i.initialTSN - 1 + if a.sourcePort != p.destinationPort || + a.destinationPort != p.sourcePort { + a.log.Warnf("[%s] handleInitAck: port mismatch", a.name) + return nil + } + + a.rwnd = i.advertisedReceiverWindowCredit + a.log.Debugf("[%s] initial rwnd=%d", a.name, a.rwnd) + + // RFC 4690 Sec 7.2.1 + // o The initial value of ssthresh MAY be arbitrarily high (for + // example, implementations MAY use the size of the receiver + // advertised window). + a.ssthresh = a.rwnd + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", + a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + + a.t1Init.stop() + a.storedInit = nil + + var cookieParam *paramStateCookie + for _, param := range i.params { + switch v := param.(type) { + case *paramStateCookie: + cookieParam = v + case *paramSupportedExtensions: + for _, t := range v.ChunkTypes { + if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on initAck)\n", a.name) + a.useForwardTSN = true + } + } + } + } + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on initAck)\n", a.name) + } + if cookieParam == nil { + return errors.Errorf("no cookie in InitAck") + } + + a.storedCookieEcho = &chunkCookieEcho{} + a.storedCookieEcho.cookie = cookieParam.cookie + + err := a.sendCookieEcho() + if err != nil { + a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) + } + + a.t1Cookie.start(a.rtoMgr.getRTO()) + a.setState(cookieEchoed) + return nil +} + +// The caller should hold the lock. +func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { + a.log.Tracef("[%s] chunkHeartbeat", a.name) + hbi, ok := c.params[0].(*paramHeartbeatInfo) + if !ok { + a.log.Warnf("[%s] failed to handle Heartbeat, no ParamHeartbeatInfo", a.name) + } + + return pack(&packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: []chunk{&chunkHeartbeatAck{ + params: []param{ + ¶mHeartbeatInfo{ + heartbeatInformation: hbi.heartbeatInformation, + }, + }, + }}, + }) +} + +// The caller should hold the lock. +func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { + state := a.getState() + a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) + switch state { + default: + return nil + case established: + if !bytes.Equal(a.myCookie.cookie, c.cookie) { + return nil + } + case closed, cookieWait, cookieEchoed: + if !bytes.Equal(a.myCookie.cookie, c.cookie) { + return nil + } + + a.t1Init.stop() + a.storedInit = nil + + a.t1Cookie.stop() + a.storedCookieEcho = nil + + a.setState(established) + a.handshakeCompletedCh <- nil + } + + p := &packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: []chunk{&chunkCookieAck{}}, + } + return pack(p) +} + +// The caller should hold the lock. +func (a *Association) handleCookieAck() { + state := a.getState() + a.log.Debugf("[%s] COOKIE-ACK received in state '%s'", a.name, getAssociationStateString(state)) + if state != cookieEchoed { + // RFC 4960 + // 5.2.5. Handle Duplicate COOKIE-ACK. + // At any state other than COOKIE-ECHOED, an endpoint should silently + // discard a received COOKIE ACK chunk. + return + } + + a.t1Cookie.stop() + a.storedCookieEcho = nil + + a.setState(established) + a.handshakeCompletedCh <- nil +} + +// The caller should hold the lock. +func (a *Association) handleData(d *chunkPayloadData) []*packet { + a.log.Tracef("[%s] DATA: tsn=%d immediateSack=%v len=%d", + a.name, d.tsn, d.immediateSack, len(d.userData)) + a.stats.incDATAs() + + canPush := a.payloadQueue.canPush(d, a.peerLastTSN) + if canPush { + s := a.getOrCreateStream(d.streamIdentifier) + if s == nil { + // silentely discard the data. (sender will retry on T3-rtx timeout) + // see pion/sctp#30 + a.log.Debugf("discard %d", d.streamSequenceNumber) + return nil + } + + if a.getMyReceiverWindowCredit() > 0 { + // Pass the new chunk to stream level as soon as it arrives + a.payloadQueue.push(d, a.peerLastTSN) + s.handleData(d) + } else { + // Receive buffer is full + lastTSN, ok := a.payloadQueue.getLastTSNReceived() + if ok && sna32LT(d.tsn, lastTSN) { + a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) + a.payloadQueue.push(d, a.peerLastTSN) + s.handleData(d) + } else { + a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) + } + } + } + + return a.handlePeerLastTSNAndAcknowledgement(d.immediateSack) +} + +// A common routine for handleData and handleForwardTSN routines +// The caller should hold the lock. +func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) []*packet { + var reply []*packet + + // Try to advance peerLastTSN + + // From RFC 3758 Sec 3.6: + // .. and then MUST further advance its cumulative TSN point locally + // if possible + // Meaning, if peerLastTSN+1 points to a chunk that is received, + // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. + for { + if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk { + break + } + a.peerLastTSN++ + + for _, rstReq := range a.reconfigRequests { + resp := a.resetStreamsIfAny(rstReq) + if resp != nil { + a.log.Debugf("[%s] RESET RESPONSE: %+v", a.name, resp) + reply = append(reply, resp) + } + } + } + + hasPacketLoss := (a.payloadQueue.size() > 0) + if hasPacketLoss { + a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString(a.peerLastTSN)) + } + + if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || a.ackMode == ackModeAlwaysDelay { + if a.ackState == ackStateIdle { + a.delayedAckTriggered = true + } else { + a.immediateAckTriggered = true + } + } else { + a.immediateAckTriggered = true + } + + return reply +} + +// The caller should hold the lock. +func (a *Association) getMyReceiverWindowCredit() uint32 { + var bytesQueued uint32 + for _, s := range a.streams { + bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) + } + + if bytesQueued >= a.maxReceiveBufferSize { + return 0 + } + return a.maxReceiveBufferSize - bytesQueued +} + +// OpenStream opens a stream +func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType PayloadProtocolIdentifier) (*Stream, error) { + a.lock.Lock() + defer a.lock.Unlock() + + if _, ok := a.streams[streamIdentifier]; ok { + return nil, errors.Errorf("there already exists a stream with identifier %d", streamIdentifier) + } + + s := a.createStream(streamIdentifier, false) + s.setDefaultPayloadType(defaultPayloadType) + + return s, nil +} + +// AcceptStream accepts a stream +func (a *Association) AcceptStream() (*Stream, error) { + s, ok := <-a.acceptCh + if !ok { + return nil, io.EOF // no more incoming streams + } + return s, nil +} + +// createStream creates a stream. The caller should hold the lock and check no stream exists for this id. +func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream { + s := &Stream{ + association: a, + streamIdentifier: streamIdentifier, + reassemblyQueue: newReassemblyQueue(streamIdentifier), + log: a.log, + name: fmt.Sprintf("%d:%s", streamIdentifier, a.name), + } + + s.readNotifier = sync.NewCond(&s.lock) + + if accept { + select { + case a.acceptCh <- s: + a.streams[streamIdentifier] = s + a.log.Debugf("[%s] accepted a new stream (streamIdentifier: %d)", + a.name, streamIdentifier) + default: + a.log.Debugf("[%s] dropped a new stream (acceptCh size: %d)", + a.name, len(a.acceptCh)) + return nil + } + } else { + a.streams[streamIdentifier] = s + } + + return s +} + +// getOrCreateStream gets or creates a stream. The caller should hold the lock. +func (a *Association) getOrCreateStream(streamIdentifier uint16) *Stream { + if s, ok := a.streams[streamIdentifier]; ok { + return s + } + + return a.createStream(streamIdentifier, true) +} + +// The caller should hold the lock. +func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int, uint32, error) { // nolint:gocognit + bytesAckedPerStream := map[uint16]int{} + + // New ack point, so pop all ACKed packets from inflightQueue + // We add 1 because the "currentAckPoint" has already been popped from the inflight queue + // For the first SACK we take care of this by setting the ackpoint to cumAck - 1 + for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, d.cumulativeTSNAck); i++ { + c, ok := a.inflightQueue.pop(i) + if !ok { + return nil, 0, errors.Errorf("tsn %v unable to be popped from inflight queue", i) + } + + if !c.acked { + // RFC 4096 sec 6.3.2. Retransmission Timer Rules + // R3) Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the + // T3-rtx timer for that address with its current RTO (if there is + // still outstanding data on that address). + if i == a.cumulativeTSNAckPoint+1 { + // T3 timer needs to be reset. Stop it for now. + a.t3RTX.stop() + } + + nBytesAcked := len(c.userData) + + // Sum the number of bytes acknowledged per stream + if amount, ok := bytesAckedPerStream[c.streamIdentifier]; ok { + bytesAckedPerStream[c.streamIdentifier] = amount + nBytesAcked + } else { + bytesAckedPerStream[c.streamIdentifier] = nBytesAcked + } + + // RFC 4960 sec 6.3.1. RTO Calculation + // C4) When data is in flight and when allowed by rule C5 below, a new + // RTT measurement MUST be made each round trip. Furthermore, new + // RTT measurements SHOULD be made no more than once per round trip + // for a given destination transport address. + // C5) Karn's algorithm: RTT measurements MUST NOT be made using + // packets that were retransmitted (and thus for which it is + // ambiguous whether the reply was for the first instance of the + // chunk or for a later instance) + if c.nSent == 1 && sna32GTE(c.tsn, a.minTSN2MeasureRTT) { + a.minTSN2MeasureRTT = a.myNextTSN + rtt := time.Since(c.since).Seconds() * 1000.0 + srtt := a.rtoMgr.setNewRTT(rtt) + a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", + a.name, rtt, srtt, a.rtoMgr.getRTO()) + } + } + + if a.inFastRecovery && c.tsn == a.fastRecoverExitPoint { + a.log.Debugf("[%s] exit fast-recovery", a.name) + a.inFastRecovery = false + } + } + + htna := d.cumulativeTSNAck + + // Mark selectively acknowledged chunks as "acked" + for _, g := range d.gapAckBlocks { + for i := g.start; i <= g.end; i++ { + tsn := d.cumulativeTSNAck + uint32(i) + c, ok := a.inflightQueue.get(tsn) + if !ok { + return nil, 0, errors.Errorf("requested non-existent TSN %v", tsn) + } + + if !c.acked { + nBytesAcked := a.inflightQueue.markAsAcked(tsn) + + // Sum the number of bytes acknowledged per stream + if amount, ok := bytesAckedPerStream[c.streamIdentifier]; ok { + bytesAckedPerStream[c.streamIdentifier] = amount + nBytesAcked + } else { + bytesAckedPerStream[c.streamIdentifier] = nBytesAcked + } + + a.log.Tracef("[%s] tsn=%d has been sacked", a.name, c.tsn) + + if c.nSent == 1 { + a.minTSN2MeasureRTT = a.myNextTSN + rtt := time.Since(c.since).Seconds() * 1000.0 + srtt := a.rtoMgr.setNewRTT(rtt) + a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", + a.name, rtt, srtt, a.rtoMgr.getRTO()) + } + + if sna32LT(htna, tsn) { + htna = tsn + } + } + } + } + + return bytesAckedPerStream, htna, nil +} + +// The caller should hold the lock. +func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { + // RFC 4096, sec 6.3.2. Retransmission Timer Rules + // R2) Whenever all outstanding data sent to an address have been + // acknowledged, turn off the T3-rtx timer of that address. + if a.inflightQueue.size() == 0 { + a.log.Tracef("[%s] SACK: no more packet in-flight (pending=%d)", a.name, a.pendingQueue.size()) + a.t3RTX.stop() + } else { + a.log.Tracef("[%s] T3-rtx timer start (pt2)", a.name) + a.t3RTX.start(a.rtoMgr.getRTO()) + } + + // Update congestion control parameters + if a.cwnd <= a.ssthresh { + // RFC 4096, sec 7.2.1. Slow-Start + // o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST + // use the slow-start algorithm to increase cwnd only if the current + // congestion window is being fully utilized, an incoming SACK + // advances the Cumulative TSN Ack Point, and the data sender is not + // in Fast Recovery. Only when these three conditions are met can + // the cwnd be increased; otherwise, the cwnd MUST not be increased. + // If these conditions are met, then cwnd MUST be increased by, at + // most, the lesser of 1) the total size of the previously + // outstanding DATA chunk(s) acknowledged, and 2) the destination's + // path MTU. + if !a.inFastRecovery && + a.pendingQueue.size() > 0 { + a.cwnd += min32(uint32(totalBytesAcked), a.cwnd) // TCP way + // a.cwnd += min32(uint32(totalBytesAcked), a.mtu) // SCTP way (slow) + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)", + a.name, a.cwnd, a.ssthresh, totalBytesAcked) + } else { + a.log.Tracef("[%s] cwnd did not grow: cwnd=%d ssthresh=%d acked=%d FR=%v pending=%d", + a.name, a.cwnd, a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size()) + } + } else { + // RFC 4096, sec 7.2.2. Congestion Avoidance + // o Whenever cwnd is greater than ssthresh, upon each SACK arrival + // that advances the Cumulative TSN Ack Point, increase + // partial_bytes_acked by the total number of bytes of all new chunks + // acknowledged in that SACK including chunks acknowledged by the new + // Cumulative TSN Ack and by Gap Ack Blocks. + a.partialBytesAcked += uint32(totalBytesAcked) + + // o When partial_bytes_acked is equal to or greater than cwnd and + // before the arrival of the SACK the sender had cwnd or more bytes + // of data outstanding (i.e., before arrival of the SACK, flight size + // was greater than or equal to cwnd), increase cwnd by MTU, and + // reset partial_bytes_acked to (partial_bytes_acked - cwnd). + if a.partialBytesAcked >= a.cwnd && a.pendingQueue.size() > 0 { + a.partialBytesAcked -= a.cwnd + a.cwnd += a.mtu + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)", + a.name, a.cwnd, a.ssthresh, totalBytesAcked) + } + } +} + +// The caller should hold the lock. +func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cumTSNAckPointAdvanced bool) error { + // HTNA algorithm - RFC 4960 Sec 7.2.4 + // Increment missIndicator of each chunks that the SACK reported missing + // when either of the following is met: + // a) Not in fast-recovery + // miss indications are incremented only for missing TSNs prior to the + // highest TSN newly acknowledged in the SACK. + // b) In fast-recovery AND the Cumulative TSN Ack Point advanced + // the miss indications are incremented for all TSNs reported missing + // in the SACK. + if !a.inFastRecovery || (a.inFastRecovery && cumTSNAckPointAdvanced) { + var maxTSN uint32 + if !a.inFastRecovery { + // a) increment only for missing TSNs prior to the HTNA + maxTSN = htna + } else { + // b) increment for all TSNs reported missing + maxTSN = cumTSNAckPoint + uint32(a.inflightQueue.size()) + 1 + } + + for tsn := cumTSNAckPoint + 1; sna32LT(tsn, maxTSN); tsn++ { + c, ok := a.inflightQueue.get(tsn) + if !ok { + return errors.Errorf("requested non-existent TSN %v", tsn) + } + if !c.acked && !c.abandoned() && c.missIndicator < 3 { + c.missIndicator++ + if c.missIndicator == 3 { + if !a.inFastRecovery { + // 2) If not in Fast Recovery, adjust the ssthresh and cwnd of the + // destination address(es) to which the missing DATA chunks were + // last sent, according to the formula described in Section 7.2.3. + a.inFastRecovery = true + a.fastRecoverExitPoint = htna + a.ssthresh = max32(a.cwnd/2, 4*a.mtu) + a.cwnd = a.ssthresh + a.partialBytesAcked = 0 + a.willRetransmitFast = true + + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (FR)", + a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + } + } + } + } + } + + if a.inFastRecovery && cumTSNAckPointAdvanced { + a.willRetransmitFast = true + } + + return nil +} + +// The caller should hold the lock. +func (a *Association) handleSack(d *chunkSelectiveAck) error { + a.log.Tracef("[%s] SACK: cumTSN=%d a_rwnd=%d", a.name, d.cumulativeTSNAck, d.advertisedReceiverWindowCredit) + state := a.getState() + if state != established { + return nil + } + + a.stats.incSACKs() + + if sna32GT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) { + // RFC 4960 sec 6.2.1. Processing a Received SACK + // D) + // i) If Cumulative TSN Ack is less than the Cumulative TSN Ack + // Point, then drop the SACK. Since Cumulative TSN Ack is + // monotonically increasing, a SACK whose Cumulative TSN Ack is + // less than the Cumulative TSN Ack Point indicates an out-of- + // order SACK. + + a.log.Debugf("[%s] SACK Cumulative ACK %v is older than ACK point %v", + a.name, + d.cumulativeTSNAck, + a.cumulativeTSNAckPoint) + + return nil + } + + // Process selective ack + bytesAckedPerStream, htna, err := a.processSelectiveAck(d) + if err != nil { + return err + } + + var totalBytesAcked int + for _, nBytesAcked := range bytesAckedPerStream { + totalBytesAcked += nBytesAcked + } + + cumTSNAckPointAdvanced := false + if sna32LT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) { + a.log.Tracef("[%s] SACK: cumTSN advanced: %d -> %d", + a.name, + a.cumulativeTSNAckPoint, + d.cumulativeTSNAck) + + a.cumulativeTSNAckPoint = d.cumulativeTSNAck + cumTSNAckPointAdvanced = true + a.onCumulativeTSNAckPointAdvanced(totalBytesAcked) + } + + for si, nBytesAcked := range bytesAckedPerStream { + if s, ok := a.streams[si]; ok { + a.lock.Unlock() + s.onBufferReleased(nBytesAcked) + a.lock.Lock() + } + } + + // New rwnd value + // RFC 4960 sec 6.2.1. Processing a Received SACK + // D) + // ii) Set rwnd equal to the newly received a_rwnd minus the number + // of bytes still outstanding after processing the Cumulative + // TSN Ack and the Gap Ack Blocks. + + // bytes acked were already subtracted by markAsAcked() method + bytesOutstanding := uint32(a.inflightQueue.getNumBytes()) + if bytesOutstanding >= d.advertisedReceiverWindowCredit { + a.rwnd = 0 + } else { + a.rwnd = d.advertisedReceiverWindowCredit - bytesOutstanding + } + + err = a.processFastRetransmission(d.cumulativeTSNAck, htna, cumTSNAckPointAdvanced) + if err != nil { + return err + } + + if a.useForwardTSN { + // RFC 3758 Sec 3.5 C1 + if sna32LT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + a.advancedPeerTSNAckPoint = a.cumulativeTSNAckPoint + } + + // RFC 3758 Sec 3.5 C2 + for i := a.advancedPeerTSNAckPoint + 1; ; i++ { + c, ok := a.inflightQueue.get(i) + if !ok { + break + } + if !c.abandoned() { + break + } + a.advancedPeerTSNAckPoint = i + } + + // RFC 3758 Sec 3.5 C3 + if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + a.willSendForwardTSN = true + } + a.awakeWriteLoop() + } + + if a.inflightQueue.size() > 0 { + // Start timer. (noop if already started) + a.log.Tracef("[%s] T3-rtx timer start (pt3)", a.name) + a.t3RTX.start(a.rtoMgr.getRTO()) + } + + if cumTSNAckPointAdvanced { + a.awakeWriteLoop() + } + + return nil +} + +// createForwardTSN generates ForwardTSN chunk. +// This method will be be called if useForwardTSN is set to false. +// The caller should hold the lock. +func (a *Association) createForwardTSN() *chunkForwardTSN { + // RFC 3758 Sec 3.5 C4 + streamMap := map[uint16]uint16{} // to report only once per SI + for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, a.advancedPeerTSNAckPoint); i++ { + c, ok := a.inflightQueue.get(i) + if !ok { + break + } + + ssn, ok := streamMap[c.streamIdentifier] + if !ok { + streamMap[c.streamIdentifier] = c.streamSequenceNumber + } else if sna16LT(ssn, c.streamSequenceNumber) { + // to report only once with greatest SSN + streamMap[c.streamIdentifier] = c.streamSequenceNumber + } + } + + fwdtsn := &chunkForwardTSN{ + newCumulativeTSN: a.advancedPeerTSNAckPoint, + streams: []chunkForwardTSNStream{}, + } + + var streamStr string + for si, ssn := range streamMap { + streamStr += fmt.Sprintf("(si=%d ssn=%d)", si, ssn) + fwdtsn.streams = append(fwdtsn.streams, chunkForwardTSNStream{ + identifier: si, + sequence: ssn, + }) + } + a.log.Tracef("[%s] building fwdtsn: newCumulativeTSN=%d cumTSN=%d - %s", a.name, fwdtsn.newCumulativeTSN, a.cumulativeTSNAckPoint, streamStr) + + return fwdtsn +} + +// createPacket wraps chunks in a packet. +// The caller should hold the read lock. +func (a *Association) createPacket(cs []chunk) *packet { + return &packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: cs, + } +} + +// The caller should hold the lock. +func (a *Association) handleReconfig(c *chunkReconfig) ([]*packet, error) { + a.log.Tracef("[%s] handleReconfig", a.name) + + pp := make([]*packet, 0) + + p, err := a.handleReconfigParam(c.paramA) + if err != nil { + return nil, err + } + if p != nil { + pp = append(pp, p) + } + + if c.paramB != nil { + p, err = a.handleReconfigParam(c.paramB) + if err != nil { + return nil, err + } + if p != nil { + pp = append(pp, p) + } + } + return pp, nil +} + +// The caller should hold the lock. +func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { + a.log.Tracef("[%s] FwdTSN: %s", a.name, c.String()) + + if !a.useForwardTSN { + a.log.Warn("[%s] received FwdTSN but not enabled") + // Return an error chunk + cerr := &chunkError{ + errorCauses: []errorCause{&errorCauseUnrecognizedChunkType{}}, + } + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + outbound.chunks = []chunk{cerr} + return []*packet{outbound} + } + + // From RFC 3758 Sec 3.6: + // Note, if the "New Cumulative TSN" value carried in the arrived + // FORWARD TSN chunk is found to be behind or at the current cumulative + // TSN point, the data receiver MUST treat this FORWARD TSN as out-of- + // date and MUST NOT update its Cumulative TSN. The receiver SHOULD + // send a SACK to its peer (the sender of the FORWARD TSN) since such a + // duplicate may indicate the previous SACK was lost in the network. + + a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d\n", + a.name, c.newCumulativeTSN, a.peerLastTSN) + if sna32LTE(c.newCumulativeTSN, a.peerLastTSN) { + a.log.Tracef("[%s] sending ack on Forward TSN", a.name) + a.ackState = ackStateImmediate + a.ackTimer.stop() + a.awakeWriteLoop() + return nil + } + + // From RFC 3758 Sec 3.6: + // the receiver MUST perform the same TSN handling, including duplicate + // detection, gap detection, SACK generation, cumulative TSN + // advancement, etc. as defined in RFC 2960 [2]---with the following + // exceptions and additions. + + // When a FORWARD TSN chunk arrives, the data receiver MUST first update + // its cumulative TSN point to the value carried in the FORWARD TSN + // chunk, + + // Advance peerLastTSN + for sna32LT(a.peerLastTSN, c.newCumulativeTSN) { + a.payloadQueue.pop(a.peerLastTSN + 1) // may not exist + a.peerLastTSN++ + } + + // Report new peerLastTSN value and abandoned largest SSN value to + // corresponding streams so that the abandoned chunks can be removed + // from the reassemblyQueue. + for _, forwarded := range c.streams { + if s, ok := a.streams[forwarded.identifier]; ok { + s.handleForwardTSNForOrdered(forwarded.sequence) + } + } + + // TSN may be forewared for unordered chunks. ForwardTSN chunk does not + // report which stream identifier it skipped for unordered chunks. + // Therefore, we need to broadcast this event to all existing streams for + // unordered chunks. + // See https://github.com/pion/sctp/issues/106 + for _, s := range a.streams { + s.handleForwardTSNForUnordered(c.newCumulativeTSN) + } + + return a.handlePeerLastTSNAndAcknowledgement(false) +} + +func (a *Association) sendResetRequest(streamIdentifier uint16) error { + a.lock.Lock() + defer a.lock.Unlock() + + state := a.getState() + if state != established { + return errors.Errorf("sending reset packet in non-established state: state=%s", + getAssociationStateString(state)) + } + + // Create DATA chunk which only contains valid stream identifier with + // nil userData and use it as a EOS from the stream. + c := &chunkPayloadData{ + streamIdentifier: streamIdentifier, + beginningFragment: true, + endingFragment: true, + userData: nil, + } + + a.pendingQueue.push(c) + a.awakeWriteLoop() + return nil +} + +// The caller should hold the lock. +func (a *Association) handleReconfigParam(raw param) (*packet, error) { + switch p := raw.(type) { + case *paramOutgoingResetRequest: + a.reconfigRequests[p.reconfigRequestSequenceNumber] = p + resp := a.resetStreamsIfAny(p) + if resp != nil { + return resp, nil + } + return nil, nil + + case *paramReconfigResponse: + delete(a.reconfigs, p.reconfigResponseSequenceNumber) + if len(a.reconfigs) == 0 { + a.tReconfig.stop() + } + return nil, nil + default: + return nil, errors.Errorf("unexpected parameter type %T", p) + } +} + +// The caller should hold the lock. +func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { + result := reconfigResultSuccessPerformed + if sna32LTE(p.senderLastTSN, a.peerLastTSN) { + a.log.Debugf("[%s] resetStream(): senderLastTSN=%d <= peerLastTSN=%d", + a.name, p.senderLastTSN, a.peerLastTSN) + for _, id := range p.streamIdentifiers { + s, ok := a.streams[id] + if !ok { + continue + } + a.unregisterStream(s, io.EOF) + } + delete(a.reconfigRequests, p.reconfigRequestSequenceNumber) + } else { + a.log.Debugf("[%s] resetStream(): senderLastTSN=%d > peerLastTSN=%d", + a.name, p.senderLastTSN, a.peerLastTSN) + result = reconfigResultInProgress + } + + return a.createPacket([]chunk{&chunkReconfig{ + paramA: ¶mReconfigResponse{ + reconfigResponseSequenceNumber: p.reconfigRequestSequenceNumber, + result: result, + }, + }}) +} + +// Move the chunk peeked with a.pendingQueue.peek() to the inflightQueue. +// The caller should hold the lock. +func (a *Association) movePendingDataChunkToInflightQueue(c *chunkPayloadData) { + if err := a.pendingQueue.pop(c); err != nil { + a.log.Errorf("[%s] failed to pop from pending queue: %s", a.name, err.Error()) + } + + // Mark all fragements are in-flight now + if c.endingFragment { + c.setAllInflight() + } + + // Assign TSN + c.tsn = a.generateNextTSN() + + c.since = time.Now() // use to calculate RTT and also for maxPacketLifeTime + c.nSent = 1 // being sent for the first time + + a.checkPartialReliabilityStatus(c) + + a.log.Tracef("[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", + a.name, c.payloadType, c.tsn, c.streamSequenceNumber, c.nSent, len(c.userData), c.beginningFragment, c.endingFragment) + + // Push it into the inflightQueue + a.inflightQueue.pushNoCheck(c) +} + +// popPendingDataChunksToSend pops chunks from the pending queues as many as +// the cwnd and rwnd allows to send. +// The caller should hold the lock. +func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint16) { + chunks := []*chunkPayloadData{} + var sisToReset []uint16 // stream identifieres to reset + + if a.pendingQueue.size() > 0 { + // RFC 4960 sec 6.1. Transmission of DATA Chunks + // A) At any given time, the data sender MUST NOT transmit new data to + // any destination transport address if its peer's rwnd indicates + // that the peer has no buffer space (i.e., rwnd is 0; see Section + // 6.2.1). However, regardless of the value of rwnd (including if it + // is 0), the data sender can always have one DATA chunk in flight to + // the receiver if allowed by cwnd (see rule B, below). + + for { + c := a.pendingQueue.peek() + if c == nil { + break // no more pending data + } + + dataLen := uint32(len(c.userData)) + if dataLen == 0 { + sisToReset = append(sisToReset, c.streamIdentifier) + err := a.pendingQueue.pop(c) + if err != nil { + a.log.Errorf("failed to pop from pending queue: %s", err.Error()) + } + continue + } + + if uint32(a.inflightQueue.getNumBytes())+dataLen > a.cwnd { + break // would exceeds cwnd + } + + if dataLen > a.rwnd { + break // no more rwnd + } + + a.rwnd -= dataLen + + a.movePendingDataChunkToInflightQueue(c) + chunks = append(chunks, c) + } + + // the data sender can always have one DATA chunk in flight to the receiver + if len(chunks) == 0 && a.inflightQueue.size() == 0 { + // Send zero window probe + c := a.pendingQueue.peek() + if c != nil { + a.movePendingDataChunkToInflightQueue(c) + chunks = append(chunks, c) + } + } + } + + return chunks, sisToReset +} + +// bundleDataChunksIntoPackets packs DATA chunks into packets. It tries to bundle +// DATA chunks into a packet so long as the resulting packet size does not exceed +// the path MTU. +// The caller should hold the lock. +func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []*packet { + packets := []*packet{} + chunksToSend := []chunk{} + bytesInPacket := int(commonHeaderSize) + + for _, c := range chunks { + // RFC 4960 sec 6.1. Transmission of DATA Chunks + // Multiple DATA chunks committed for transmission MAY be bundled in a + // single packet. Furthermore, DATA chunks being retransmitted MAY be + // bundled with new DATA chunks, as long as the resulting packet size + // does not exceed the path MTU. + if bytesInPacket+len(c.userData) > int(a.mtu) { + packets = append(packets, a.createPacket(chunksToSend)) + chunksToSend = []chunk{} + bytesInPacket = int(commonHeaderSize) + } + + chunksToSend = append(chunksToSend, c) + bytesInPacket += int(dataChunkHeaderSize) + len(c.userData) + } + + if len(chunksToSend) > 0 { + packets = append(packets, a.createPacket(chunksToSend)) + } + + return packets +} + +// sendPayloadData sends the data chunks. +func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error { + a.lock.Lock() + defer a.lock.Unlock() + + state := a.getState() + if state != established { + return errors.Errorf("sending payload data in non-established state: state=%s", + getAssociationStateString(state)) + } + + // Push the chunks into the pending queue first. + for _, c := range chunks { + a.pendingQueue.push(c) + } + + a.awakeWriteLoop() + return nil +} + +// The caller should hold the lock. +func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { + if !a.useForwardTSN { + return + } + + // draft-ietf-rtcweb-data-protocol-09.txt section 6 + // 6. Procedures + // All Data Channel Establishment Protocol messages MUST be sent using + // ordered delivery and reliable transmission. + // + if c.payloadType == PayloadTypeWebRTCDCEP { + return + } + + // PR-SCTP + if s, ok := a.streams[c.streamIdentifier]; ok { + s.lock.RLock() + if s.reliabilityType == ReliabilityTypeRexmit { + if c.nSent >= s.reliabilityValue { + c.setAbandoned(true) + a.log.Tracef("[%s] marked as abandoned: tsn=%d ppi=%d (remix: %d)", a.name, c.tsn, c.payloadType, c.nSent) + } + } else if s.reliabilityType == ReliabilityTypeTimed { + elapsed := int64(time.Since(c.since).Seconds() * 1000) + if elapsed >= int64(s.reliabilityValue) { + c.setAbandoned(true) + a.log.Tracef("[%s] marked as abandoned: tsn=%d ppi=%d (timed: %d)", a.name, c.tsn, c.payloadType, elapsed) + } + } + s.lock.RUnlock() + } else { + a.log.Errorf("[%s] stream %d not found)", a.name, c.streamIdentifier) + } +} + +// getDataPacketsToRetransmit is called when T3-rtx is timed out and retransmit outstanding data chunks +// that are not acked or abandoned yet. +// The caller should hold the lock. +func (a *Association) getDataPacketsToRetransmit() []*packet { + awnd := min32(a.cwnd, a.rwnd) + chunks := []*chunkPayloadData{} + var bytesToSend int + var done bool + + for i := 0; !done; i++ { + c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) + if !ok { + break // end of pending data + } + + if !c.retransmit { + continue + } + + if i == 0 && int(a.rwnd) < len(c.userData) { + // Send it as a zero window probe + done = true + } else if bytesToSend+len(c.userData) > int(awnd) { + break + } + + // reset the retransmit flag not to retransmit again before the next + // t3-rtx timer fires + c.retransmit = false + bytesToSend += len(c.userData) + + c.nSent++ + + a.checkPartialReliabilityStatus(c) + + a.log.Tracef("[%s] retransmitting tsn=%d ssn=%d sent=%d", a.name, c.tsn, c.streamSequenceNumber, c.nSent) + + chunks = append(chunks, c) + } + + return a.bundleDataChunksIntoPackets(chunks) +} + +// generateNextTSN returns the myNextTSN and increases it. The caller should hold the lock. +// The caller should hold the lock. +func (a *Association) generateNextTSN() uint32 { + tsn := a.myNextTSN + a.myNextTSN++ + return tsn +} + +// generateNextRSN returns the myNextRSN and increases it. The caller should hold the lock. +// The caller should hold the lock. +func (a *Association) generateNextRSN() uint32 { + rsn := a.myNextRSN + a.myNextRSN++ + return rsn +} + +func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { + sack := &chunkSelectiveAck{} + sack.cumulativeTSNAck = a.peerLastTSN + sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() + sack.duplicateTSN = a.payloadQueue.popDuplicates() + sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks(a.peerLastTSN) + return sack +} + +func pack(p *packet) []*packet { + return []*packet{p} +} + +func (a *Association) handleChunkStart() { + a.lock.Lock() + defer a.lock.Unlock() + + a.delayedAckTriggered = false + a.immediateAckTriggered = false +} + +func (a *Association) handleChunkEnd() { + a.lock.Lock() + defer a.lock.Unlock() + + if a.immediateAckTriggered { + // Send SACK now! + a.ackState = ackStateImmediate + a.ackTimer.stop() + a.awakeWriteLoop() + } else if a.delayedAckTriggered { + // Will send delayed ack in the next ack timeout + a.ackState = ackStateDelay + a.ackTimer.start() + } +} + +func (a *Association) handleChunk(p *packet, c chunk) error { + a.lock.Lock() + defer a.lock.Unlock() + + var packets []*packet + var err error + + if _, err = c.check(); err != nil { + a.log.Errorf("[ %s ] failed validating chunk: %s ", a.name, err) + return nil + } + + switch c := c.(type) { + case *chunkInit: + packets, err = a.handleInit(p, c) + + case *chunkInitAck: + err = a.handleInitAck(p, c) + + case *chunkAbort: + var errStr string + for _, e := range c.errorCauses { + errStr += fmt.Sprintf("(%s)", e) + } + return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr) + + case *chunkError: + var errStr string + for _, e := range c.errorCauses { + errStr += fmt.Sprintf("(%s)", e) + } + a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr) + + case *chunkHeartbeat: + packets = a.handleHeartbeat(c) + + case *chunkCookieEcho: + packets = a.handleCookieEcho(c) + + case *chunkCookieAck: + a.handleCookieAck() + + case *chunkPayloadData: + packets = a.handleData(c) + + case *chunkSelectiveAck: + err = a.handleSack(c) + + case *chunkReconfig: + packets, err = a.handleReconfig(c) + + case *chunkForwardTSN: + packets = a.handleForwardTSN(c) + + default: + err = errors.Errorf("unhandled chunk type") + } + + // Log and return, the only condition that is fatal is a ABORT chunk + if err != nil { + a.log.Errorf("Failed to handle chunk: %v", err) + return nil + } + + if len(packets) > 0 { + a.controlQueue.pushAll(packets) + a.awakeWriteLoop() + } + + return nil +} + +func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { + a.lock.Lock() + defer a.lock.Unlock() + + if id == timerT1Init { + err := a.sendInit() + if err != nil { + a.log.Debugf("[%s] failed to retransmit init (nRtos=%d): %v", a.name, nRtos, err) + } + return + } + + if id == timerT1Cookie { + err := a.sendCookieEcho() + if err != nil { + a.log.Debugf("[%s] failed to retransmit cookie-echo (nRtos=%d): %v", a.name, nRtos, err) + } + return + } + + if id == timerT3RTX { + a.stats.incT3Timeouts() + + // RFC 4960 sec 6.3.3 + // E1) For the destination address for which the timer expires, adjust + // its ssthresh with rules defined in Section 7.2.3 and set the + // cwnd <- MTU. + // RFC 4960 sec 7.2.3 + // When the T3-rtx timer expires on an address, SCTP should perform slow + // start by: + // ssthresh = max(cwnd/2, 4*MTU) + // cwnd = 1*MTU + + a.ssthresh = max32(a.cwnd/2, 4*a.mtu) + a.cwnd = a.mtu + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (RTO)", + a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + + // RFC 3758 sec 3.5 + // A5) Any time the T3-rtx timer expires, on any destination, the sender + // SHOULD try to advance the "Advanced.Peer.Ack.Point" by following + // the procedures outlined in C2 - C5. + if a.useForwardTSN { + // RFC 3758 Sec 3.5 C2 + for i := a.advancedPeerTSNAckPoint + 1; ; i++ { + c, ok := a.inflightQueue.get(i) + if !ok { + break + } + if !c.abandoned() { + break + } + a.advancedPeerTSNAckPoint = i + } + + // RFC 3758 Sec 3.5 C3 + if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + a.willSendForwardTSN = true + } + } + + a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.cwnd, a.ssthresh) + + /* + a.log.Debugf(" - advancedPeerTSNAckPoint=%d", a.advancedPeerTSNAckPoint) + a.log.Debugf(" - cumulativeTSNAckPoint=%d", a.cumulativeTSNAckPoint) + a.inflightQueue.updateSortedKeys() + for i, tsn := range a.inflightQueue.sorted { + if c, ok := a.inflightQueue.get(tsn); ok { + a.log.Debugf(" - [%d] tsn=%d acked=%v abandoned=%v (%v,%v) len=%d", + i, c.tsn, c.acked, c.abandoned(), c.beginningFragment, c.endingFragment, len(c.userData)) + } + } + */ + + a.inflightQueue.markAllToRetrasmit() + a.awakeWriteLoop() + + return + } + + if id == timerReconfig { + a.willRetransmitReconfig = true + a.awakeWriteLoop() + } +} + +func (a *Association) onRetransmissionFailure(id int) { + a.lock.Lock() + defer a.lock.Unlock() + + if id == timerT1Init { + a.log.Errorf("[%s] retransmission failure: T1-init", a.name) + a.handshakeCompletedCh <- errors.Errorf("handshake failed (INIT ACK)") + return + } + + if id == timerT1Cookie { + a.log.Errorf("[%s] retransmission failure: T1-cookie", a.name) + a.handshakeCompletedCh <- errors.Errorf("handshake failed (COOKIE ECHO)") + return + } + + if id == timerT3RTX { + // T3-rtx timer will not fail by design + // Justifications: + // * ICE would fail if the connectivity is lost + // * WebRTC spec is not clear how this incident should be reported to ULP + a.log.Errorf("[%s] retransmission failure: T3-rtx (DATA)", a.name) + return + } +} + +func (a *Association) onAckTimeout() { + a.lock.Lock() + defer a.lock.Unlock() + + a.log.Tracef("[%s] ack timed out (ackState: %d)", a.name, a.ackState) + a.stats.incAckTimeouts() + + a.ackState = ackStateImmediate + a.awakeWriteLoop() +} + +// bufferedAmount returns total amount (in bytes) of currently buffered user data. +// This is used only by testing. +func (a *Association) bufferedAmount() int { + a.lock.RLock() + defer a.lock.RUnlock() + + return a.pendingQueue.getNumBytes() + a.inflightQueue.getNumBytes() +} + +// MaxMessageSize returns the maximum message size you can send. +func (a *Association) MaxMessageSize() uint32 { + return atomic.LoadUint32(&a.maxMessageSize) +} + +// SetMaxMessageSize sets the maximum message size you can send. +func (a *Association) SetMaxMessageSize(maxMsgSize uint32) { + atomic.StoreUint32(&a.maxMessageSize, maxMsgSize) +} |