summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/srtp/v2/session_srtcp.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/srtp/v2/session_srtcp.go')
-rw-r--r--vendor/github.com/pion/srtp/v2/session_srtcp.go180
1 files changed, 180 insertions, 0 deletions
diff --git a/vendor/github.com/pion/srtp/v2/session_srtcp.go b/vendor/github.com/pion/srtp/v2/session_srtcp.go
new file mode 100644
index 0000000..a5fb656
--- /dev/null
+++ b/vendor/github.com/pion/srtp/v2/session_srtcp.go
@@ -0,0 +1,180 @@
+package srtp
+
+import (
+ "net"
+ "time"
+
+ "github.com/pion/logging"
+ "github.com/pion/rtcp"
+)
+
+const defaultSessionSRTCPReplayProtectionWindow = 64
+
+// SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session
+// SRTCP itself does not have a design like this, but it is common in most applications
+// for local/remote to each have their own keying material. This provides those patterns
+// instead of making everyone re-implement
+type SessionSRTCP struct {
+ session
+ writeStream *WriteStreamSRTCP
+}
+
+// NewSessionSRTCP creates a SRTCP session using conn as the underlying transport.
+func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl
+ if config == nil {
+ return nil, errNoConfig
+ } else if conn == nil {
+ return nil, errNoConn
+ }
+
+ loggerFactory := config.LoggerFactory
+ if loggerFactory == nil {
+ loggerFactory = logging.NewDefaultLoggerFactory()
+ }
+
+ localOpts := append(
+ []ContextOption{},
+ config.LocalOptions...,
+ )
+ remoteOpts := append(
+ []ContextOption{
+ // Default options
+ SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow),
+ },
+ config.RemoteOptions...,
+ )
+
+ s := &SessionSRTCP{
+ session: session{
+ nextConn: conn,
+ localOptions: localOpts,
+ remoteOptions: remoteOpts,
+ readStreams: map[uint32]readStream{},
+ newStream: make(chan readStream),
+ started: make(chan interface{}),
+ closed: make(chan interface{}),
+ bufferFactory: config.BufferFactory,
+ log: loggerFactory.NewLogger("srtp"),
+ },
+ }
+ s.writeStream = &WriteStreamSRTCP{s}
+
+ err := s.session.start(
+ config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
+ config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
+ config.Profile,
+ s,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+// OpenWriteStream returns the global write stream for the Session
+func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) {
+ return s.writeStream, nil
+}
+
+// OpenReadStream opens a read stream for the given SSRC, it can be used
+// if you want a certain SSRC, but don't want to wait for AcceptStream
+func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) {
+ r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
+
+ if readStream, ok := r.(*ReadStreamSRTCP); ok {
+ return readStream, nil
+ }
+ return nil, errFailedTypeAssertion
+}
+
+// AcceptStream returns a stream to handle RTCP for a single SSRC
+func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
+ stream, ok := <-s.newStream
+ if !ok {
+ return nil, 0, errStreamAlreadyClosed
+ }
+
+ readStream, ok := stream.(*ReadStreamSRTCP)
+ if !ok {
+ return nil, 0, errFailedTypeAssertion
+ }
+
+ return readStream, stream.GetSSRC(), nil
+}
+
+// Close ends the session
+func (s *SessionSRTCP) Close() error {
+ return s.session.close()
+}
+
+// Private
+
+func (s *SessionSRTCP) write(buf []byte) (int, error) {
+ if _, ok := <-s.session.started; ok {
+ return 0, errStartedChannelUsedIncorrectly
+ }
+
+ s.session.localContextMutex.Lock()
+ encrypted, err := s.localContext.EncryptRTCP(nil, buf, nil)
+ s.session.localContextMutex.Unlock()
+
+ if err != nil {
+ return 0, err
+ }
+ return s.session.nextConn.Write(encrypted)
+}
+
+func (s *SessionSRTCP) setWriteDeadline(t time.Time) error {
+ return s.session.nextConn.SetWriteDeadline(t)
+}
+
+// create a list of Destination SSRCs
+// that's a superset of all Destinations in the slice.
+func destinationSSRC(pkts []rtcp.Packet) []uint32 {
+ ssrcSet := make(map[uint32]struct{})
+ for _, p := range pkts {
+ for _, ssrc := range p.DestinationSSRC() {
+ ssrcSet[ssrc] = struct{}{}
+ }
+ }
+
+ out := make([]uint32, 0, len(ssrcSet))
+ for ssrc := range ssrcSet {
+ out = append(out, ssrc)
+ }
+
+ return out
+}
+
+func (s *SessionSRTCP) decrypt(buf []byte) error {
+ decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
+ if err != nil {
+ return err
+ }
+
+ pkt, err := rtcp.Unmarshal(decrypted)
+ if err != nil {
+ return err
+ }
+
+ for _, ssrc := range destinationSSRC(pkt) {
+ r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
+ if r == nil {
+ return nil // Session has been closed
+ } else if isNew {
+ s.session.newStream <- r // Notify AcceptStream
+ }
+
+ readStream, ok := r.(*ReadStreamSRTCP)
+ if !ok {
+ return errFailedTypeAssertion
+ }
+
+ _, err = readStream.write(decrypted)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}