summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/dtls/v2/flight1handler.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/dtls/v2/flight1handler.go')
-rw-r--r--vendor/github.com/pion/dtls/v2/flight1handler.go112
1 files changed, 112 insertions, 0 deletions
diff --git a/vendor/github.com/pion/dtls/v2/flight1handler.go b/vendor/github.com/pion/dtls/v2/flight1handler.go
new file mode 100644
index 0000000..9229292
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/flight1handler.go
@@ -0,0 +1,112 @@
+package dtls
+
+import (
+ "context"
+
+ "github.com/pion/dtls/v2/pkg/crypto/elliptic"
+ "github.com/pion/dtls/v2/pkg/protocol"
+ "github.com/pion/dtls/v2/pkg/protocol/alert"
+ "github.com/pion/dtls/v2/pkg/protocol/extension"
+ "github.com/pion/dtls/v2/pkg/protocol/handshake"
+ "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
+ // HelloVerifyRequest can be skipped by the server,
+ // so allow ServerHello during flight1 also
+ seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
+ handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
+ handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true},
+ )
+ if !ok {
+ // No valid message received. Keep reading
+ return 0, nil, nil
+ }
+
+ if _, ok := msgs[handshake.TypeServerHello]; ok {
+ // Flight1 and flight2 were skipped.
+ // Parse as flight3.
+ return flight3Parse(ctx, c, state, cache, cfg)
+ }
+
+ if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok {
+ // DTLS 1.2 clients must not assume that the server will use the protocol version
+ // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
+ if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
+ return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
+ }
+ state.cookie = append([]byte{}, h.Cookie...)
+ state.handshakeRecvSequence = seq
+ return flight3, nil, nil
+ }
+
+ return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
+}
+
+func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
+ var zeroEpoch uint16
+ state.localEpoch.Store(zeroEpoch)
+ state.remoteEpoch.Store(zeroEpoch)
+ state.namedCurve = defaultNamedCurve
+ state.cookie = nil
+
+ if err := state.localRandom.Populate(); err != nil {
+ return nil, nil, err
+ }
+
+ extensions := []extension.Extension{
+ &extension.SupportedSignatureAlgorithms{
+ SignatureHashAlgorithms: cfg.localSignatureSchemes,
+ },
+ &extension.RenegotiationInfo{
+ RenegotiatedConnection: 0,
+ },
+ }
+ if cfg.localPSKCallback == nil {
+ extensions = append(extensions, []extension.Extension{
+ &extension.SupportedEllipticCurves{
+ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
+ },
+ &extension.SupportedPointFormats{
+ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
+ },
+ }...)
+ }
+
+ if len(cfg.localSRTPProtectionProfiles) > 0 {
+ extensions = append(extensions, &extension.UseSRTP{
+ ProtectionProfiles: cfg.localSRTPProtectionProfiles,
+ })
+ }
+
+ if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
+ cfg.extendedMasterSecret == RequireExtendedMasterSecret {
+ extensions = append(extensions, &extension.UseExtendedMasterSecret{
+ Supported: true,
+ })
+ }
+
+ if len(cfg.serverName) > 0 {
+ extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
+ }
+
+ return []*packet{
+ {
+ record: &recordlayer.RecordLayer{
+ Header: recordlayer.Header{
+ Version: protocol.Version1_2,
+ },
+ Content: &handshake.Handshake{
+ Message: &handshake.MessageClientHello{
+ Version: protocol.Version1_2,
+ Cookie: state.cookie,
+ Random: state.localRandom,
+ CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
+ CompressionMethods: defaultCompressionMethods(),
+ Extensions: extensions,
+ },
+ },
+ },
+ },
+ }, nil, nil
+}