summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/dtls/v2/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/dtls/v2/conn.go')
-rw-r--r--vendor/github.com/pion/dtls/v2/conn.go978
1 files changed, 978 insertions, 0 deletions
diff --git a/vendor/github.com/pion/dtls/v2/conn.go b/vendor/github.com/pion/dtls/v2/conn.go
new file mode 100644
index 0000000..42d732f
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/conn.go
@@ -0,0 +1,978 @@
+package dtls
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pion/dtls/v2/internal/closer"
+ "github.com/pion/dtls/v2/pkg/crypto/elliptic"
+ "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
+ "github.com/pion/dtls/v2/pkg/protocol"
+ "github.com/pion/dtls/v2/pkg/protocol/alert"
+ "github.com/pion/dtls/v2/pkg/protocol/handshake"
+ "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+ "github.com/pion/logging"
+ "github.com/pion/transport/connctx"
+ "github.com/pion/transport/deadline"
+ "github.com/pion/transport/replaydetector"
+)
+
+const (
+ initialTickerInterval = time.Second
+ cookieLength = 20
+ defaultNamedCurve = elliptic.X25519
+ inboundBufferSize = 8192
+ // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
+ defaultReplayProtectionWindow = 64
+)
+
+func invalidKeyingLabels() map[string]bool {
+ return map[string]bool{
+ "client finished": true,
+ "server finished": true,
+ "master secret": true,
+ "key expansion": true,
+ }
+}
+
+// Conn represents a DTLS connection
+type Conn struct {
+ lock sync.RWMutex // Internal lock (must not be public)
+ nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
+ fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
+ handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
+ decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
+
+ state State // Internal state
+
+ maximumTransmissionUnit int
+
+ handshakeCompletedSuccessfully atomic.Value
+
+ encryptedPackets [][]byte
+
+ connectionClosedByUser bool
+ closeLock sync.Mutex
+ closed *closer.Closer
+ handshakeLoopsFinished sync.WaitGroup
+
+ readDeadline *deadline.Deadline
+ writeDeadline *deadline.Deadline
+
+ log logging.LeveledLogger
+
+ reading chan struct{}
+ handshakeRecv chan chan struct{}
+ cancelHandshaker func()
+ cancelHandshakeReader func()
+
+ fsm *handshakeFSM
+
+ replayProtectionWindow uint
+}
+
+func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
+ err := validateConfig(config)
+ if err != nil {
+ return nil, err
+ }
+
+ if nextConn == nil {
+ return nil, errNilNextConn
+ }
+
+ cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
+ if err != nil {
+ return nil, err
+ }
+
+ signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
+ if err != nil {
+ return nil, err
+ }
+
+ workerInterval := initialTickerInterval
+ if config.FlightInterval != 0 {
+ workerInterval = config.FlightInterval
+ }
+
+ loggerFactory := config.LoggerFactory
+ if loggerFactory == nil {
+ loggerFactory = logging.NewDefaultLoggerFactory()
+ }
+
+ logger := loggerFactory.NewLogger("dtls")
+
+ mtu := config.MTU
+ if mtu <= 0 {
+ mtu = defaultMTU
+ }
+
+ replayProtectionWindow := config.ReplayProtectionWindow
+ if replayProtectionWindow <= 0 {
+ replayProtectionWindow = defaultReplayProtectionWindow
+ }
+
+ c := &Conn{
+ nextConn: connctx.New(nextConn),
+ fragmentBuffer: newFragmentBuffer(),
+ handshakeCache: newHandshakeCache(),
+ maximumTransmissionUnit: mtu,
+
+ decrypted: make(chan interface{}, 1),
+ log: logger,
+
+ readDeadline: deadline.New(),
+ writeDeadline: deadline.New(),
+
+ reading: make(chan struct{}, 1),
+ handshakeRecv: make(chan chan struct{}),
+ closed: closer.NewCloser(),
+ cancelHandshaker: func() {},
+
+ replayProtectionWindow: uint(replayProtectionWindow),
+
+ state: State{
+ isClient: isClient,
+ },
+ }
+
+ c.setRemoteEpoch(0)
+ c.setLocalEpoch(0)
+
+ serverName := config.ServerName
+ // Use host from conn address when serverName is not provided
+ if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
+ remoteAddr := nextConn.RemoteAddr().String()
+ var host string
+ host, _, err = net.SplitHostPort(remoteAddr)
+ if err != nil {
+ serverName = remoteAddr
+ } else {
+ serverName = host
+ }
+ }
+
+ hsCfg := &handshakeConfig{
+ localPSKCallback: config.PSK,
+ localPSKIdentityHint: config.PSKIdentityHint,
+ localCipherSuites: cipherSuites,
+ localSignatureSchemes: signatureSchemes,
+ extendedMasterSecret: config.ExtendedMasterSecret,
+ localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
+ serverName: serverName,
+ clientAuth: config.ClientAuth,
+ localCertificates: config.Certificates,
+ insecureSkipVerify: config.InsecureSkipVerify,
+ verifyPeerCertificate: config.VerifyPeerCertificate,
+ rootCAs: config.RootCAs,
+ clientCAs: config.ClientCAs,
+ customCipherSuites: config.CustomCipherSuites,
+ retransmitInterval: workerInterval,
+ log: logger,
+ initialEpoch: 0,
+ keyLogWriter: config.KeyLogWriter,
+ }
+
+ var initialFlight flightVal
+ var initialFSMState handshakeState
+
+ if initialState != nil {
+ if c.state.isClient {
+ initialFlight = flight5
+ } else {
+ initialFlight = flight6
+ }
+ initialFSMState = handshakeFinished
+
+ c.state = *initialState
+ } else {
+ if c.state.isClient {
+ initialFlight = flight1
+ } else {
+ initialFlight = flight0
+ }
+ initialFSMState = handshakePreparing
+ }
+ // Do handshake
+ if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
+ return nil, err
+ }
+
+ c.log.Trace("Handshake Completed")
+
+ return c, nil
+}
+
+// Dial connects to the given network address and establishes a DTLS connection on top.
+// Connection handshake will timeout using ConnectContextMaker in the Config.
+// If you want to specify the timeout duration, use DialWithContext() instead.
+func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
+ ctx, cancel := config.connectContextMaker()
+ defer cancel()
+
+ return DialWithContext(ctx, network, raddr, config)
+}
+
+// Client establishes a DTLS connection over an existing connection.
+// Connection handshake will timeout using ConnectContextMaker in the Config.
+// If you want to specify the timeout duration, use ClientWithContext() instead.
+func Client(conn net.Conn, config *Config) (*Conn, error) {
+ ctx, cancel := config.connectContextMaker()
+ defer cancel()
+
+ return ClientWithContext(ctx, conn, config)
+}
+
+// Server listens for incoming DTLS connections.
+// Connection handshake will timeout using ConnectContextMaker in the Config.
+// If you want to specify the timeout duration, use ServerWithContext() instead.
+func Server(conn net.Conn, config *Config) (*Conn, error) {
+ ctx, cancel := config.connectContextMaker()
+ defer cancel()
+
+ return ServerWithContext(ctx, conn, config)
+}
+
+// DialWithContext connects to the given network address and establishes a DTLS connection on top.
+func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
+ pConn, err := net.DialUDP(network, nil, raddr)
+ if err != nil {
+ return nil, err
+ }
+ return ClientWithContext(ctx, pConn, config)
+}
+
+// ClientWithContext establishes a DTLS connection over an existing connection.
+func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
+ switch {
+ case config == nil:
+ return nil, errNoConfigProvided
+ case config.PSK != nil && config.PSKIdentityHint == nil:
+ return nil, errPSKAndIdentityMustBeSetForClient
+ }
+
+ return createConn(ctx, conn, config, true, nil)
+}
+
+// ServerWithContext listens for incoming DTLS connections.
+func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
+ if config == nil {
+ return nil, errNoConfigProvided
+ }
+
+ return createConn(ctx, conn, config, false, nil)
+}
+
+// Read reads data from the connection.
+func (c *Conn) Read(p []byte) (n int, err error) {
+ if !c.isHandshakeCompletedSuccessfully() {
+ return 0, errHandshakeInProgress
+ }
+
+ select {
+ case <-c.readDeadline.Done():
+ return 0, errDeadlineExceeded
+ default:
+ }
+
+ for {
+ select {
+ case <-c.readDeadline.Done():
+ return 0, errDeadlineExceeded
+ case out, ok := <-c.decrypted:
+ if !ok {
+ return 0, io.EOF
+ }
+ switch val := out.(type) {
+ case ([]byte):
+ if len(p) < len(val) {
+ return 0, errBufferTooSmall
+ }
+ copy(p, val)
+ return len(val), nil
+ case (error):
+ return 0, val
+ }
+ }
+ }
+}
+
+// Write writes len(p) bytes from p to the DTLS connection
+func (c *Conn) Write(p []byte) (int, error) {
+ if c.isConnectionClosed() {
+ return 0, ErrConnClosed
+ }
+
+ select {
+ case <-c.writeDeadline.Done():
+ return 0, errDeadlineExceeded
+ default:
+ }
+
+ if !c.isHandshakeCompletedSuccessfully() {
+ return 0, errHandshakeInProgress
+ }
+
+ return len(p), c.writePackets(c.writeDeadline, []*packet{
+ {
+ record: &recordlayer.RecordLayer{
+ Header: recordlayer.Header{
+ Epoch: c.getLocalEpoch(),
+ Version: protocol.Version1_2,
+ },
+ Content: &protocol.ApplicationData{
+ Data: p,
+ },
+ },
+ shouldEncrypt: true,
+ },
+ })
+}
+
+// Close closes the connection.
+func (c *Conn) Close() error {
+ err := c.close(true)
+ c.handshakeLoopsFinished.Wait()
+ return err
+}
+
+// ConnectionState returns basic DTLS details about the connection.
+// Note that this replaced the `Export` function of v1.
+func (c *Conn) ConnectionState() State {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return *c.state.clone()
+}
+
+// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
+func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+
+ if c.state.srtpProtectionProfile == 0 {
+ return 0, false
+ }
+
+ return c.state.srtpProtectionProfile, true
+}
+
+func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ var rawPackets [][]byte
+
+ for _, p := range pkts {
+ if h, ok := p.record.Content.(*handshake.Handshake); ok {
+ handshakeRaw, err := p.record.Marshal()
+ if err != nil {
+ return err
+ }
+
+ c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
+ srvCliStr(c.state.isClient), h.Header.Type.String(),
+ p.record.Header.Epoch, h.Header.MessageSequence)
+ c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
+
+ rawHandshakePackets, err := c.processHandshakePacket(p, h)
+ if err != nil {
+ return err
+ }
+ rawPackets = append(rawPackets, rawHandshakePackets...)
+ } else {
+ rawPacket, err := c.processPacket(p)
+ if err != nil {
+ return err
+ }
+ rawPackets = append(rawPackets, rawPacket)
+ }
+ }
+ if len(rawPackets) == 0 {
+ return nil
+ }
+ compactedRawPackets := c.compactRawPackets(rawPackets)
+
+ for _, compactedRawPackets := range compactedRawPackets {
+ if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
+ return netError(err)
+ }
+ }
+
+ return nil
+}
+
+func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
+ combinedRawPackets := make([][]byte, 0)
+ currentCombinedRawPacket := make([]byte, 0)
+
+ for _, rawPacket := range rawPackets {
+ if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
+ combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
+ currentCombinedRawPacket = []byte{}
+ }
+ currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
+ }
+
+ combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
+
+ return combinedRawPackets
+}
+
+func (c *Conn) processPacket(p *packet) ([]byte, error) {
+ epoch := p.record.Header.Epoch
+ for len(c.state.localSequenceNumber) <= int(epoch) {
+ c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
+ }
+ seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
+ if seq > recordlayer.MaxSequenceNumber {
+ // RFC 6347 Section 4.1.0
+ // The implementation must either abandon an association or rehandshake
+ // prior to allowing the sequence number to wrap.
+ return nil, errSequenceNumberOverflow
+ }
+ p.record.Header.SequenceNumber = seq
+
+ rawPacket, err := p.record.Marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ if p.shouldEncrypt {
+ var err error
+ rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return rawPacket, nil
+}
+
+func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
+ rawPackets := make([][]byte, 0)
+
+ handshakeFragments, err := c.fragmentHandshake(h)
+ if err != nil {
+ return nil, err
+ }
+ epoch := p.record.Header.Epoch
+ for len(c.state.localSequenceNumber) <= int(epoch) {
+ c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
+ }
+
+ for _, handshakeFragment := range handshakeFragments {
+ seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
+ if seq > recordlayer.MaxSequenceNumber {
+ return nil, errSequenceNumberOverflow
+ }
+
+ recordlayerHeader := &recordlayer.Header{
+ Version: p.record.Header.Version,
+ ContentType: p.record.Header.ContentType,
+ ContentLen: uint16(len(handshakeFragment)),
+ Epoch: p.record.Header.Epoch,
+ SequenceNumber: seq,
+ }
+
+ recordlayerHeaderBytes, err := recordlayerHeader.Marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ p.record.Header = *recordlayerHeader
+
+ rawPacket := append(recordlayerHeaderBytes, handshakeFragment...)
+ if p.shouldEncrypt {
+ var err error
+ rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ rawPackets = append(rawPackets, rawPacket)
+ }
+
+ return rawPackets, nil
+}
+
+func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
+ content, err := h.Message.Marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ fragmentedHandshakes := make([][]byte, 0)
+
+ contentFragments := splitBytes(content, c.maximumTransmissionUnit)
+ if len(contentFragments) == 0 {
+ contentFragments = [][]byte{
+ {},
+ }
+ }
+
+ offset := 0
+ for _, contentFragment := range contentFragments {
+ contentFragmentLen := len(contentFragment)
+
+ headerFragment := &handshake.Header{
+ Type: h.Header.Type,
+ Length: h.Header.Length,
+ MessageSequence: h.Header.MessageSequence,
+ FragmentOffset: uint32(offset),
+ FragmentLength: uint32(contentFragmentLen),
+ }
+
+ offset += contentFragmentLen
+
+ headerFragmentRaw, err := headerFragment.Marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ fragmentedHandshake := append(headerFragmentRaw, contentFragment...)
+ fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
+ }
+
+ return fragmentedHandshakes, nil
+}
+
+var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
+ New: func() interface{} {
+ b := make([]byte, inboundBufferSize)
+ return &b
+ },
+}
+
+func (c *Conn) readAndBuffer(ctx context.Context) error {
+ bufptr := poolReadBuffer.Get().(*[]byte)
+ defer poolReadBuffer.Put(bufptr)
+
+ b := *bufptr
+ i, err := c.nextConn.ReadContext(ctx, b)
+ if err != nil {
+ return netError(err)
+ }
+
+ pkts, err := recordlayer.UnpackDatagram(b[:i])
+ if err != nil {
+ return err
+ }
+
+ var hasHandshake bool
+ for _, p := range pkts {
+ hs, alert, err := c.handleIncomingPacket(p, true)
+ if alert != nil {
+ if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
+ if err == nil {
+ err = alertErr
+ }
+ }
+ }
+ if hs {
+ hasHandshake = true
+ }
+ switch e := err.(type) {
+ case nil:
+ case *errAlert:
+ if e.IsFatalOrCloseNotify() {
+ return e
+ }
+ default:
+ return e
+ }
+ }
+ if hasHandshake {
+ done := make(chan struct{})
+ select {
+ case c.handshakeRecv <- done:
+ // If the other party may retransmit the flight,
+ // we should respond even if it not a new message.
+ <-done
+ case <-c.fsm.Done():
+ }
+ }
+ return nil
+}
+
+func (c *Conn) handleQueuedPackets(ctx context.Context) error {
+ pkts := c.encryptedPackets
+ c.encryptedPackets = nil
+
+ for _, p := range pkts {
+ _, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
+ if alert != nil {
+ if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
+ if err == nil {
+ err = alertErr
+ }
+ }
+ }
+ switch e := err.(type) {
+ case nil:
+ case *errAlert:
+ if e.IsFatalOrCloseNotify() {
+ return e
+ }
+ default:
+ return e
+ }
+ }
+ return nil
+}
+
+func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
+ h := &recordlayer.Header{}
+ if err := h.Unmarshal(buf); err != nil {
+ // Decode error must be silently discarded
+ // [RFC6347 Section-4.1.2.7]
+ c.log.Debugf("discarded broken packet: %v", err)
+ return false, nil, nil
+ }
+
+ // Validate epoch
+ remoteEpoch := c.getRemoteEpoch()
+ if h.Epoch > remoteEpoch {
+ if h.Epoch > remoteEpoch+1 {
+ c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
+ h.Epoch, h.SequenceNumber,
+ )
+ return false, nil, nil
+ }
+ if enqueue {
+ c.log.Debug("received packet of next epoch, queuing packet")
+ c.encryptedPackets = append(c.encryptedPackets, buf)
+ }
+ return false, nil, nil
+ }
+
+ // Anti-replay protection
+ for len(c.state.replayDetector) <= int(h.Epoch) {
+ c.state.replayDetector = append(c.state.replayDetector,
+ replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
+ )
+ }
+ markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
+ if !ok {
+ c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
+ h.Epoch, h.SequenceNumber,
+ )
+ return false, nil, nil
+ }
+
+ // Decrypt
+ if h.Epoch != 0 {
+ if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
+ if enqueue {
+ c.encryptedPackets = append(c.encryptedPackets, buf)
+ c.log.Debug("handshake not finished, queuing packet")
+ }
+ return false, nil, nil
+ }
+
+ var err error
+ buf, err = c.state.cipherSuite.Decrypt(buf)
+ if err != nil {
+ c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
+ return false, nil, nil
+ }
+ }
+
+ isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
+ if err != nil {
+ // Decode error must be silently discarded
+ // [RFC6347 Section-4.1.2.7]
+ c.log.Debugf("defragment failed: %s", err)
+ return false, nil, nil
+ } else if isHandshake {
+ markPacketAsValid()
+ for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
+ rawHandshake := &handshake.Handshake{}
+ if err := rawHandshake.Unmarshal(out); err != nil {
+ c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
+ continue
+ }
+
+ _ = c.handshakeCache.push(out, epoch, rawHandshake.Header.MessageSequence, rawHandshake.Header.Type, !c.state.isClient)
+ }
+
+ return true, nil, nil
+ }
+
+ r := &recordlayer.RecordLayer{}
+ if err := r.Unmarshal(buf); err != nil {
+ return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
+ }
+
+ switch content := r.Content.(type) {
+ case *alert.Alert:
+ c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
+ var a *alert.Alert
+ if content.Description == alert.CloseNotify {
+ // Respond with a close_notify [RFC5246 Section 7.2.1]
+ a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
+ }
+ markPacketAsValid()
+ return false, a, &errAlert{content}
+ case *protocol.ChangeCipherSpec:
+ if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
+ if enqueue {
+ c.encryptedPackets = append(c.encryptedPackets, buf)
+ c.log.Debugf("CipherSuite not initialized, queuing packet")
+ }
+ return false, nil, nil
+ }
+
+ newRemoteEpoch := h.Epoch + 1
+ c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
+
+ if c.getRemoteEpoch()+1 == newRemoteEpoch {
+ c.setRemoteEpoch(newRemoteEpoch)
+ markPacketAsValid()
+ }
+ case *protocol.ApplicationData:
+ if h.Epoch == 0 {
+ return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
+ }
+
+ markPacketAsValid()
+
+ select {
+ case c.decrypted <- content.Data:
+ case <-c.closed.Done():
+ }
+
+ default:
+ return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
+ }
+ return false, nil, nil
+}
+
+func (c *Conn) recvHandshake() <-chan chan struct{} {
+ return c.handshakeRecv
+}
+
+func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
+ return c.writePackets(ctx, []*packet{
+ {
+ record: &recordlayer.RecordLayer{
+ Header: recordlayer.Header{
+ Epoch: c.getLocalEpoch(),
+ Version: protocol.Version1_2,
+ },
+ Content: &alert.Alert{
+ Level: level,
+ Description: desc,
+ },
+ },
+ shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
+ },
+ })
+}
+
+func (c *Conn) setHandshakeCompletedSuccessfully() {
+ c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
+}
+
+func (c *Conn) isHandshakeCompletedSuccessfully() bool {
+ boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
+ return boolean.bool
+}
+
+func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
+ c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
+
+ done := make(chan struct{})
+ ctxRead, cancelRead := context.WithCancel(context.Background())
+ c.cancelHandshakeReader = cancelRead
+ cfg.onFlightState = func(f flightVal, s handshakeState) {
+ if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
+ c.setHandshakeCompletedSuccessfully()
+ close(done)
+ }
+ }
+
+ ctxHs, cancel := context.WithCancel(context.Background())
+ c.cancelHandshaker = cancel
+
+ firstErr := make(chan error, 1)
+
+ c.handshakeLoopsFinished.Add(2)
+
+ // Handshake routine should be live until close.
+ // The other party may request retransmission of the last flight to cope with packet drop.
+ go func() {
+ defer c.handshakeLoopsFinished.Done()
+ err := c.fsm.Run(ctxHs, c, initialState)
+ if !errors.Is(err, context.Canceled) {
+ select {
+ case firstErr <- err:
+ default:
+ }
+ }
+ }()
+ go func() {
+ defer func() {
+ // Escaping read loop.
+ // It's safe to close decrypted channnel now.
+ close(c.decrypted)
+
+ // Force stop handshaker when the underlying connection is closed.
+ cancel()
+ }()
+ defer c.handshakeLoopsFinished.Done()
+ for {
+ if err := c.readAndBuffer(ctxRead); err != nil {
+ switch e := err.(type) {
+ case *errAlert:
+ if !e.IsFatalOrCloseNotify() {
+ if c.isHandshakeCompletedSuccessfully() {
+ // Pass the error to Read()
+ select {
+ case c.decrypted <- err:
+ case <-c.closed.Done():
+ }
+ }
+ continue // non-fatal alert must not stop read loop
+ }
+ case error:
+ switch err {
+ case context.DeadlineExceeded, context.Canceled, io.EOF:
+ default:
+ if c.isHandshakeCompletedSuccessfully() {
+ // Keep read loop and pass the read error to Read()
+ select {
+ case c.decrypted <- err:
+ case <-c.closed.Done():
+ }
+ continue // non-fatal alert must not stop read loop
+ }
+ }
+ }
+ select {
+ case firstErr <- err:
+ default:
+ }
+
+ if e, ok := err.(*errAlert); ok {
+ if e.IsFatalOrCloseNotify() {
+ _ = c.close(false)
+ }
+ }
+ return
+ }
+ }
+ }()
+
+ select {
+ case err := <-firstErr:
+ cancelRead()
+ cancel()
+ return c.translateHandshakeCtxError(err)
+ case <-ctx.Done():
+ cancelRead()
+ cancel()
+ return c.translateHandshakeCtxError(ctx.Err())
+ case <-done:
+ return nil
+ }
+}
+
+func (c *Conn) translateHandshakeCtxError(err error) error {
+ if err == nil {
+ return nil
+ }
+ if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
+ return nil
+ }
+ return &HandshakeError{Err: err}
+}
+
+func (c *Conn) close(byUser bool) error {
+ c.cancelHandshaker()
+ c.cancelHandshakeReader()
+
+ if c.isHandshakeCompletedSuccessfully() && byUser {
+ // Discard error from notify() to return non-error on the first user call of Close()
+ // even if the underlying connection is already closed.
+ _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
+ }
+
+ c.closeLock.Lock()
+ // Don't return ErrConnClosed at the first time of the call from user.
+ closedByUser := c.connectionClosedByUser
+ if byUser {
+ c.connectionClosedByUser = true
+ }
+ c.closed.Close()
+ c.closeLock.Unlock()
+
+ if closedByUser {
+ return ErrConnClosed
+ }
+
+ return c.nextConn.Close()
+}
+
+func (c *Conn) isConnectionClosed() bool {
+ select {
+ case <-c.closed.Done():
+ return true
+ default:
+ return false
+ }
+}
+
+func (c *Conn) setLocalEpoch(epoch uint16) {
+ c.state.localEpoch.Store(epoch)
+}
+
+func (c *Conn) getLocalEpoch() uint16 {
+ return c.state.localEpoch.Load().(uint16)
+}
+
+func (c *Conn) setRemoteEpoch(epoch uint16) {
+ c.state.remoteEpoch.Store(epoch)
+}
+
+func (c *Conn) getRemoteEpoch() uint16 {
+ return c.state.remoteEpoch.Load().(uint16)
+}
+
+// LocalAddr implements net.Conn.LocalAddr
+func (c *Conn) LocalAddr() net.Addr {
+ return c.nextConn.LocalAddr()
+}
+
+// RemoteAddr implements net.Conn.RemoteAddr
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.nextConn.RemoteAddr()
+}
+
+// SetDeadline implements net.Conn.SetDeadline
+func (c *Conn) SetDeadline(t time.Time) error {
+ c.readDeadline.Set(t)
+ return c.SetWriteDeadline(t)
+}
+
+// SetReadDeadline implements net.Conn.SetReadDeadline
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ c.readDeadline.Set(t)
+ // Read deadline is fully managed by this layer.
+ // Don't set read deadline to underlying connection.
+ return nil
+}
+
+// SetWriteDeadline implements net.Conn.SetWriteDeadline
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ c.writeDeadline.Set(t)
+ // Write deadline is also fully managed by this layer.
+ return nil
+}