summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/dtls/v2/handshaker.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/dtls/v2/handshaker.go')
-rw-r--r--vendor/github.com/pion/dtls/v2/handshaker.go334
1 files changed, 334 insertions, 0 deletions
diff --git a/vendor/github.com/pion/dtls/v2/handshaker.go b/vendor/github.com/pion/dtls/v2/handshaker.go
new file mode 100644
index 0000000..0f5077e
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/handshaker.go
@@ -0,0 +1,334 @@
+package dtls
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io"
+ "sync"
+ "time"
+
+ "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
+ "github.com/pion/dtls/v2/pkg/protocol/alert"
+ "github.com/pion/dtls/v2/pkg/protocol/handshake"
+ "github.com/pion/logging"
+)
+
+// [RFC6347 Section-4.2.4]
+// +-----------+
+// +---> | PREPARING | <--------------------+
+// | +-----------+ |
+// | | |
+// | | Buffer next flight |
+// | | |
+// | \|/ |
+// | +-----------+ |
+// | | SENDING |<------------------+ | Send
+// | +-----------+ | | HelloRequest
+// Receive | | | |
+// next | | Send flight | | or
+// flight | +--------+ | |
+// | | | Set retransmit timer | | Receive
+// | | \|/ | | HelloRequest
+// | | +-----------+ | | Send
+// +--)--| WAITING |-------------------+ | ClientHello
+// | | +-----------+ Timer expires | |
+// | | | | |
+// | | +------------------------+ |
+// Receive | | Send Read retransmit |
+// last | | last |
+// flight | | flight |
+// | | |
+// \|/\|/ |
+// +-----------+ |
+// | FINISHED | -------------------------------+
+// +-----------+
+// | /|\
+// | |
+// +---+
+// Read retransmit
+// Retransmit last flight
+
+type handshakeState uint8
+
+const (
+ handshakeErrored handshakeState = iota
+ handshakePreparing
+ handshakeSending
+ handshakeWaiting
+ handshakeFinished
+)
+
+func (s handshakeState) String() string {
+ switch s {
+ case handshakeErrored:
+ return "Errored"
+ case handshakePreparing:
+ return "Preparing"
+ case handshakeSending:
+ return "Sending"
+ case handshakeWaiting:
+ return "Waiting"
+ case handshakeFinished:
+ return "Finished"
+ default:
+ return "Unknown"
+ }
+}
+
+type handshakeFSM struct {
+ currentFlight flightVal
+ flights []*packet
+ retransmit bool
+ state *State
+ cache *handshakeCache
+ cfg *handshakeConfig
+ closed chan struct{}
+}
+
+type handshakeConfig struct {
+ localPSKCallback PSKCallback
+ localPSKIdentityHint []byte
+ localCipherSuites []CipherSuite // Available CipherSuites
+ localSignatureSchemes []signaturehash.Algorithm // Available signature schemes
+ extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
+ localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
+ serverName string
+ clientAuth ClientAuthType // If we are a client should we request a client certificate
+ localCertificates []tls.Certificate
+ nameToCertificate map[string]*tls.Certificate
+ insecureSkipVerify bool
+ verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
+ rootCAs *x509.CertPool
+ clientCAs *x509.CertPool
+ retransmitInterval time.Duration
+ customCipherSuites func() []CipherSuite
+
+ onFlightState func(flightVal, handshakeState)
+ log logging.LeveledLogger
+ keyLogWriter io.Writer
+
+ initialEpoch uint16
+
+ mu sync.Mutex
+}
+
+type flightConn interface {
+ notify(ctx context.Context, level alert.Level, desc alert.Description) error
+ writePackets(context.Context, []*packet) error
+ recvHandshake() <-chan chan struct{}
+ setLocalEpoch(epoch uint16)
+ handleQueuedPackets(context.Context) error
+}
+
+func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
+ if c.keyLogWriter == nil {
+ return
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)))
+ if err != nil {
+ c.log.Debugf("failed to write key log file: %s", err)
+ }
+}
+
+func srvCliStr(isClient bool) string {
+ if isClient {
+ return "client"
+ }
+ return "server"
+}
+
+func newHandshakeFSM(
+ s *State, cache *handshakeCache, cfg *handshakeConfig,
+ initialFlight flightVal,
+) *handshakeFSM {
+ return &handshakeFSM{
+ currentFlight: initialFlight,
+ state: s,
+ cache: cache,
+ cfg: cfg,
+ closed: make(chan struct{}),
+ }
+}
+
+func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
+ state := initialState
+ defer func() {
+ close(s.closed)
+ }()
+ for {
+ s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
+ if s.cfg.onFlightState != nil {
+ s.cfg.onFlightState(s.currentFlight, state)
+ }
+ var err error
+ switch state {
+ case handshakePreparing:
+ state, err = s.prepare(ctx, c)
+ case handshakeSending:
+ state, err = s.send(ctx, c)
+ case handshakeWaiting:
+ state, err = s.wait(ctx, c)
+ case handshakeFinished:
+ state, err = s.finish(ctx, c)
+ default:
+ return errInvalidFSMTransition
+ }
+ if err != nil {
+ return err
+ }
+ }
+}
+
+func (s *handshakeFSM) Done() <-chan struct{} {
+ return s.closed
+}
+
+func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
+ s.flights = nil
+ // Prepare flights
+ var (
+ a *alert.Alert
+ err error
+ pkts []*packet
+ )
+ gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
+ if errFlight != nil {
+ err = errFlight
+ a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
+ } else {
+ pkts, a, err = gen(c, s.state, s.cache, s.cfg)
+ s.retransmit = retransmit
+ }
+ if a != nil {
+ if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
+ if err != nil {
+ err = alertErr
+ }
+ }
+ }
+ if err != nil {
+ return handshakeErrored, err
+ }
+
+ s.flights = pkts
+ epoch := s.cfg.initialEpoch
+ nextEpoch := epoch
+ for _, p := range s.flights {
+ p.record.Header.Epoch += epoch
+ if p.record.Header.Epoch > nextEpoch {
+ nextEpoch = p.record.Header.Epoch
+ }
+ if h, ok := p.record.Content.(*handshake.Handshake); ok {
+ h.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
+ s.state.handshakeSendSequence++
+ }
+ }
+ if epoch != nextEpoch {
+ s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
+ c.setLocalEpoch(nextEpoch)
+ }
+ return handshakeSending, nil
+}
+
+func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
+ // Send flights
+ if err := c.writePackets(ctx, s.flights); err != nil {
+ return handshakeErrored, err
+ }
+
+ if s.currentFlight.isLastSendFlight() {
+ return handshakeFinished, nil
+ }
+ return handshakeWaiting, nil
+}
+
+func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
+ parse, errFlight := s.currentFlight.getFlightParser()
+ if errFlight != nil {
+ if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
+ if errFlight != nil {
+ return handshakeErrored, alertErr
+ }
+ }
+ return handshakeErrored, errFlight
+ }
+
+ retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
+ for {
+ select {
+ case done := <-c.recvHandshake():
+ nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
+ close(done)
+ if alert != nil {
+ if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
+ if err != nil {
+ err = alertErr
+ }
+ }
+ }
+ if err != nil {
+ return handshakeErrored, err
+ }
+ if nextFlight == 0 {
+ break
+ }
+ s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
+ if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
+ return handshakeFinished, nil
+ }
+ s.currentFlight = nextFlight
+ return handshakePreparing, nil
+
+ case <-retransmitTimer.C:
+ if !s.retransmit {
+ return handshakeWaiting, nil
+ }
+ return handshakeSending, nil
+ case <-ctx.Done():
+ return handshakeErrored, ctx.Err()
+ }
+ }
+}
+
+func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
+ parse, errFlight := s.currentFlight.getFlightParser()
+ if errFlight != nil {
+ if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
+ if errFlight != nil {
+ return handshakeErrored, alertErr
+ }
+ }
+ return handshakeErrored, errFlight
+ }
+
+ retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
+ select {
+ case done := <-c.recvHandshake():
+ nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
+ close(done)
+ if alert != nil {
+ if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
+ if err != nil {
+ err = alertErr
+ }
+ }
+ }
+ if err != nil {
+ return handshakeErrored, err
+ }
+ if nextFlight == 0 {
+ break
+ }
+ <-retransmitTimer.C
+ // Retransmit last flight
+ return handshakeSending, nil
+
+ case <-ctx.Done():
+ return handshakeErrored, ctx.Err()
+ }
+ return handshakeFinished, nil
+}