diff options
Diffstat (limited to 'vendor/github.com/pion/webrtc/v3/dtlstransport.go')
-rw-r--r-- | vendor/github.com/pion/webrtc/v3/dtlstransport.go | 430 |
1 files changed, 430 insertions, 0 deletions
diff --git a/vendor/github.com/pion/webrtc/v3/dtlstransport.go b/vendor/github.com/pion/webrtc/v3/dtlstransport.go new file mode 100644 index 0000000..cc25889 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v3/dtlstransport.go @@ -0,0 +1,430 @@ +// +build !js + +package webrtc + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/dtls/v2" + "github.com/pion/dtls/v2/pkg/crypto/fingerprint" + "github.com/pion/logging" + "github.com/pion/srtp/v2" + "github.com/pion/webrtc/v3/internal/mux" + "github.com/pion/webrtc/v3/internal/util" + "github.com/pion/webrtc/v3/pkg/rtcerr" +) + +// DTLSTransport allows an application access to information about the DTLS +// transport over which RTP and RTCP packets are sent and received by +// RTPSender and RTPReceiver, as well other data such as SCTP packets sent +// and received by data channels. +type DTLSTransport struct { + lock sync.RWMutex + + iceTransport *ICETransport + certificates []Certificate + remoteParameters DTLSParameters + remoteCertificate []byte + state DTLSTransportState + srtpProtectionProfile srtp.ProtectionProfile + + onStateChangeHandler func(DTLSTransportState) + + conn *dtls.Conn + + srtpSession, srtcpSession atomic.Value + srtpEndpoint, srtcpEndpoint *mux.Endpoint + simulcastStreams []*srtp.ReadStreamSRTP + srtpReady chan struct{} + + dtlsMatcher mux.MatchFunc + + api *API + log logging.LeveledLogger +} + +// NewDTLSTransport creates a new DTLSTransport. +// This constructor is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) { + t := &DTLSTransport{ + iceTransport: transport, + api: api, + state: DTLSTransportStateNew, + dtlsMatcher: mux.MatchDTLS, + srtpReady: make(chan struct{}), + log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"), + } + + if len(certificates) > 0 { + now := time.Now() + for _, x509Cert := range certificates { + if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) { + return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired} + } + t.certificates = append(t.certificates, x509Cert) + } + } else { + sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, &rtcerr.UnknownError{Err: err} + } + certificate, err := GenerateCertificate(sk) + if err != nil { + return nil, err + } + t.certificates = []Certificate{*certificate} + } + + return t, nil +} + +// ICETransport returns the currently-configured *ICETransport or nil +// if one has not been configured +func (t *DTLSTransport) ICETransport() *ICETransport { + t.lock.RLock() + defer t.lock.RUnlock() + return t.iceTransport +} + +// onStateChange requires the caller holds the lock +func (t *DTLSTransport) onStateChange(state DTLSTransportState) { + t.state = state + handler := t.onStateChangeHandler + if handler != nil { + handler(state) + } +} + +// OnStateChange sets a handler that is fired when the DTLS +// connection state changes. +func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) { + t.lock.Lock() + defer t.lock.Unlock() + t.onStateChangeHandler = f +} + +// State returns the current dtls transport state. +func (t *DTLSTransport) State() DTLSTransportState { + t.lock.RLock() + defer t.lock.RUnlock() + return t.state +} + +// GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction. +func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) { + fingerprints := []DTLSFingerprint{} + + for _, c := range t.certificates { + prints, err := c.GetFingerprints() + if err != nil { + return DTLSParameters{}, err + } + + fingerprints = append(fingerprints, prints...) + } + + return DTLSParameters{ + Role: DTLSRoleAuto, // always returns the default role + Fingerprints: fingerprints, + }, nil +} + +// GetRemoteCertificate returns the certificate chain in use by the remote side +// returns an empty list prior to selection of the remote certificate +func (t *DTLSTransport) GetRemoteCertificate() []byte { + t.lock.RLock() + defer t.lock.RUnlock() + return t.remoteCertificate +} + +func (t *DTLSTransport) startSRTP() error { + srtpConfig := &srtp.Config{ + Profile: t.srtpProtectionProfile, + BufferFactory: t.api.settingEngine.BufferFactory, + LoggerFactory: t.api.settingEngine.LoggerFactory, + } + if t.api.settingEngine.replayProtection.SRTP != nil { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP), + ) + } + + if t.api.settingEngine.disableSRTPReplayProtection { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTPNoReplayProtection(), + ) + } + + if t.api.settingEngine.replayProtection.SRTCP != nil { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP), + ) + } + + if t.api.settingEngine.disableSRTCPReplayProtection { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTCPNoReplayProtection(), + ) + } + + connState := t.conn.ConnectionState() + err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient) + if err != nil { + return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) + } + + srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) + if err != nil { + return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) + } + + srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig) + if err != nil { + return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) + } + + t.srtpSession.Store(srtpSession) + t.srtcpSession.Store(srtcpSession) + close(t.srtpReady) + return nil +} + +func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) { + if value := t.srtpSession.Load(); value != nil { + return value.(*srtp.SessionSRTP), nil + } + + return nil, errDtlsTransportNotStarted +} + +func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) { + if value := t.srtcpSession.Load(); value != nil { + return value.(*srtp.SessionSRTCP), nil + } + + return nil, errDtlsTransportNotStarted +} + +func (t *DTLSTransport) role() DTLSRole { + // If remote has an explicit role use the inverse + switch t.remoteParameters.Role { + case DTLSRoleClient: + return DTLSRoleServer + case DTLSRoleServer: + return DTLSRoleClient + default: + } + + // If SettingEngine has an explicit role + switch t.api.settingEngine.answeringDTLSRole { + case DTLSRoleServer: + return DTLSRoleServer + case DTLSRoleClient: + return DTLSRoleClient + default: + } + + // Remote was auto and no explicit role was configured via SettingEngine + if t.iceTransport.Role() == ICERoleControlling { + return DTLSRoleServer + } + return defaultDtlsRoleAnswer +} + +// Start DTLS transport negotiation with the parameters of the remote DTLS transport +func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { + // Take lock and prepare connection, we must not hold the lock + // when connecting + prepareTransport := func() (DTLSRole, *dtls.Config, error) { + t.lock.Lock() + defer t.lock.Unlock() + + if err := t.ensureICEConn(); err != nil { + return DTLSRole(0), nil, err + } + + if t.state != DTLSTransportStateNew { + return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)} + } + + t.srtpEndpoint = t.iceTransport.NewEndpoint(mux.MatchSRTP) + t.srtcpEndpoint = t.iceTransport.NewEndpoint(mux.MatchSRTCP) + t.remoteParameters = remoteParameters + + cert := t.certificates[0] + t.onStateChange(DTLSTransportStateConnecting) + + return t.role(), &dtls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{cert.x509Cert.Raw}, + PrivateKey: cert.privateKey, + }, + }, + SRTPProtectionProfiles: []dtls.SRTPProtectionProfile{dtls.SRTP_AEAD_AES_128_GCM, dtls.SRTP_AES128_CM_HMAC_SHA1_80}, + ClientAuth: dtls.RequireAnyClientCert, + LoggerFactory: t.api.settingEngine.LoggerFactory, + InsecureSkipVerify: true, + }, nil + } + + var dtlsConn *dtls.Conn + dtlsEndpoint := t.iceTransport.NewEndpoint(mux.MatchDTLS) + role, dtlsConfig, err := prepareTransport() + if err != nil { + return err + } + + if t.api.settingEngine.replayProtection.DTLS != nil { + dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS) + } + + // Connect as DTLS Client/Server, function is blocking and we + // must not hold the DTLSTransport lock + if role == DTLSRoleClient { + dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig) + } else { + dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig) + } + + // Re-take the lock, nothing beyond here is blocking + t.lock.Lock() + defer t.lock.Unlock() + + if err != nil { + t.onStateChange(DTLSTransportStateFailed) + return err + } + + srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile() + if !ok { + t.onStateChange(DTLSTransportStateFailed) + return ErrNoSRTPProtectionProfile + } + + switch srtpProfile { + case dtls.SRTP_AEAD_AES_128_GCM: + t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm + case dtls.SRTP_AES128_CM_HMAC_SHA1_80: + t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80 + default: + t.onStateChange(DTLSTransportStateFailed) + return ErrNoSRTPProtectionProfile + } + + if t.api.settingEngine.disableCertificateFingerprintVerification { + return nil + } + + // Check the fingerprint if a certificate was exchanged + remoteCerts := dtlsConn.ConnectionState().PeerCertificates + if len(remoteCerts) == 0 { + t.onStateChange(DTLSTransportStateFailed) + return errNoRemoteCertificate + } + t.remoteCertificate = remoteCerts[0] + + parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate) + if err != nil { + if closeErr := dtlsConn.Close(); closeErr != nil { + t.log.Error(err.Error()) + } + + t.onStateChange(DTLSTransportStateFailed) + return err + } + + if err = t.validateFingerPrint(parsedRemoteCert); err != nil { + if closeErr := dtlsConn.Close(); closeErr != nil { + t.log.Error(err.Error()) + } + + t.onStateChange(DTLSTransportStateFailed) + return err + } + + t.conn = dtlsConn + t.onStateChange(DTLSTransportStateConnected) + + return t.startSRTP() +} + +// Stop stops and closes the DTLSTransport object. +func (t *DTLSTransport) Stop() error { + t.lock.Lock() + defer t.lock.Unlock() + + // Try closing everything and collect the errors + var closeErrs []error + + if srtpSessionValue := t.srtpSession.Load(); srtpSessionValue != nil { + closeErrs = append(closeErrs, srtpSessionValue.(*srtp.SessionSRTP).Close()) + } + + if srtcpSessionValue := t.srtcpSession.Load(); srtcpSessionValue != nil { + closeErrs = append(closeErrs, srtcpSessionValue.(*srtp.SessionSRTCP).Close()) + } + + for i := range t.simulcastStreams { + closeErrs = append(closeErrs, t.simulcastStreams[i].Close()) + } + + if t.conn != nil { + // dtls connection may be closed on sctp close. + if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) { + closeErrs = append(closeErrs, err) + } + } + t.onStateChange(DTLSTransportStateClosed) + return util.FlattenErrs(closeErrs) +} + +func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error { + for _, fp := range t.remoteParameters.Fingerprints { + hashAlgo, err := fingerprint.HashFromString(fp.Algorithm) + if err != nil { + return err + } + + remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo) + if err != nil { + return err + } + + if strings.EqualFold(remoteValue, fp.Value) { + return nil + } + } + + return errNoMatchingCertificateFingerprint +} + +func (t *DTLSTransport) ensureICEConn() error { + if t.iceTransport == nil || t.iceTransport.State() == ICETransportStateNew { + return errICEConnectionNotStarted + } + + return nil +} + +func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) { + t.lock.Lock() + defer t.lock.Unlock() + + t.simulcastStreams = append(t.simulcastStreams, s) +} |