diff options
Diffstat (limited to 'vendor/github.com/pion/srtp/v2/session_srtcp.go')
-rw-r--r-- | vendor/github.com/pion/srtp/v2/session_srtcp.go | 180 |
1 files changed, 180 insertions, 0 deletions
diff --git a/vendor/github.com/pion/srtp/v2/session_srtcp.go b/vendor/github.com/pion/srtp/v2/session_srtcp.go new file mode 100644 index 0000000..a5fb656 --- /dev/null +++ b/vendor/github.com/pion/srtp/v2/session_srtcp.go @@ -0,0 +1,180 @@ +package srtp + +import ( + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/rtcp" +) + +const defaultSessionSRTCPReplayProtectionWindow = 64 + +// SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session +// SRTCP itself does not have a design like this, but it is common in most applications +// for local/remote to each have their own keying material. This provides those patterns +// instead of making everyone re-implement +type SessionSRTCP struct { + session + writeStream *WriteStreamSRTCP +} + +// NewSessionSRTCP creates a SRTCP session using conn as the underlying transport. +func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl + if config == nil { + return nil, errNoConfig + } else if conn == nil { + return nil, errNoConn + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + localOpts := append( + []ContextOption{}, + config.LocalOptions..., + ) + remoteOpts := append( + []ContextOption{ + // Default options + SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow), + }, + config.RemoteOptions..., + ) + + s := &SessionSRTCP{ + session: session{ + nextConn: conn, + localOptions: localOpts, + remoteOptions: remoteOpts, + readStreams: map[uint32]readStream{}, + newStream: make(chan readStream), + started: make(chan interface{}), + closed: make(chan interface{}), + bufferFactory: config.BufferFactory, + log: loggerFactory.NewLogger("srtp"), + }, + } + s.writeStream = &WriteStreamSRTCP{s} + + err := s.session.start( + config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, + config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, + config.Profile, + s, + ) + if err != nil { + return nil, err + } + return s, nil +} + +// OpenWriteStream returns the global write stream for the Session +func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) { + return s.writeStream, nil +} + +// OpenReadStream opens a read stream for the given SSRC, it can be used +// if you want a certain SSRC, but don't want to wait for AcceptStream +func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) { + r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) + + if readStream, ok := r.(*ReadStreamSRTCP); ok { + return readStream, nil + } + return nil, errFailedTypeAssertion +} + +// AcceptStream returns a stream to handle RTCP for a single SSRC +func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { + stream, ok := <-s.newStream + if !ok { + return nil, 0, errStreamAlreadyClosed + } + + readStream, ok := stream.(*ReadStreamSRTCP) + if !ok { + return nil, 0, errFailedTypeAssertion + } + + return readStream, stream.GetSSRC(), nil +} + +// Close ends the session +func (s *SessionSRTCP) Close() error { + return s.session.close() +} + +// Private + +func (s *SessionSRTCP) write(buf []byte) (int, error) { + if _, ok := <-s.session.started; ok { + return 0, errStartedChannelUsedIncorrectly + } + + s.session.localContextMutex.Lock() + encrypted, err := s.localContext.EncryptRTCP(nil, buf, nil) + s.session.localContextMutex.Unlock() + + if err != nil { + return 0, err + } + return s.session.nextConn.Write(encrypted) +} + +func (s *SessionSRTCP) setWriteDeadline(t time.Time) error { + return s.session.nextConn.SetWriteDeadline(t) +} + +// create a list of Destination SSRCs +// that's a superset of all Destinations in the slice. +func destinationSSRC(pkts []rtcp.Packet) []uint32 { + ssrcSet := make(map[uint32]struct{}) + for _, p := range pkts { + for _, ssrc := range p.DestinationSSRC() { + ssrcSet[ssrc] = struct{}{} + } + } + + out := make([]uint32, 0, len(ssrcSet)) + for ssrc := range ssrcSet { + out = append(out, ssrc) + } + + return out +} + +func (s *SessionSRTCP) decrypt(buf []byte) error { + decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) + if err != nil { + return err + } + + pkt, err := rtcp.Unmarshal(decrypted) + if err != nil { + return err + } + + for _, ssrc := range destinationSSRC(pkt) { + r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) + if r == nil { + return nil // Session has been closed + } else if isNew { + s.session.newStream <- r // Notify AcceptStream + } + + readStream, ok := r.(*ReadStreamSRTCP) + if !ok { + return errFailedTypeAssertion + } + + _, err = readStream.write(decrypted) + if err != nil { + return err + } + } + + return nil +} |