summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/mdns/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/mdns/conn.go')
-rw-r--r--vendor/github.com/pion/mdns/conn.go316
1 files changed, 316 insertions, 0 deletions
diff --git a/vendor/github.com/pion/mdns/conn.go b/vendor/github.com/pion/mdns/conn.go
new file mode 100644
index 0000000..a8aafb2
--- /dev/null
+++ b/vendor/github.com/pion/mdns/conn.go
@@ -0,0 +1,316 @@
+package mdns
+
+import (
+ "context"
+ "math/big"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/pion/logging"
+ "golang.org/x/net/dns/dnsmessage"
+ "golang.org/x/net/ipv4"
+)
+
+// Conn represents a mDNS Server
+type Conn struct {
+ mu sync.RWMutex
+ log logging.LeveledLogger
+
+ socket *ipv4.PacketConn
+ dstAddr *net.UDPAddr
+
+ queryInterval time.Duration
+ localNames []string
+ queries []query
+
+ closed chan interface{}
+}
+
+type query struct {
+ nameWithSuffix string
+ queryResultChan chan queryResult
+}
+
+type queryResult struct {
+ answer dnsmessage.ResourceHeader
+ addr net.Addr
+}
+
+const (
+ inboundBufferSize = 512
+ defaultQueryInterval = time.Second
+ destinationAddress = "224.0.0.251:5353"
+ maxMessageRecords = 3
+ responseTTL = 120
+)
+
+// Server establishes a mDNS connection over an existing conn
+func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
+ if config == nil {
+ return nil, errNilConfig
+ }
+
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil, err
+ }
+
+ joinErrCount := 0
+ for i := range ifaces {
+ if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
+ joinErrCount++
+ }
+ }
+ if joinErrCount >= len(ifaces) {
+ return nil, errJoiningMulticastGroup
+ }
+
+ dstAddr, err := net.ResolveUDPAddr("udp", destinationAddress)
+ if err != nil {
+ return nil, err
+
+ }
+
+ loggerFactory := config.LoggerFactory
+ if loggerFactory == nil {
+ loggerFactory = logging.NewDefaultLoggerFactory()
+ }
+
+ localNames := []string{}
+ for _, l := range config.LocalNames {
+ localNames = append(localNames, l+".")
+ }
+
+ c := &Conn{
+ queryInterval: defaultQueryInterval,
+ queries: []query{},
+ socket: conn,
+ dstAddr: dstAddr,
+ localNames: localNames,
+ log: loggerFactory.NewLogger("mdns"),
+ closed: make(chan interface{}),
+ }
+ if config.QueryInterval != 0 {
+ c.queryInterval = config.QueryInterval
+ }
+
+ go c.start()
+ return c, nil
+}
+
+// Close closes the mDNS Conn
+func (c *Conn) Close() error {
+ select {
+ case <-c.closed:
+ return nil
+ default:
+ }
+
+ if err := c.socket.Close(); err != nil {
+ return err
+ }
+
+ <-c.closed
+ return nil
+}
+
+// Query sends mDNS Queries for the following name until
+// either the Context is canceled/expires or we get a result
+func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) {
+ select {
+ case <-c.closed:
+ return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
+ default:
+ }
+
+ nameWithSuffix := name + "."
+
+ queryChan := make(chan queryResult, 1)
+ c.mu.Lock()
+ c.queries = append(c.queries, query{nameWithSuffix, queryChan})
+ ticker := time.NewTicker(c.queryInterval)
+ c.mu.Unlock()
+
+ c.sendQuestion(nameWithSuffix)
+ for {
+ select {
+ case <-ticker.C:
+ c.sendQuestion(nameWithSuffix)
+ case <-c.closed:
+ return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
+ case res := <-queryChan:
+ return res.answer, res.addr, nil
+ case <-ctx.Done():
+ return dnsmessage.ResourceHeader{}, nil, errContextElapsed
+ }
+ }
+}
+
+func ipToBytes(ip net.IP) (out [4]byte) {
+ rawIP := ip.To4()
+ if rawIP == nil {
+ return
+ }
+
+ ipInt := big.NewInt(0)
+ ipInt.SetBytes(rawIP)
+ copy(out[:], ipInt.Bytes())
+ return
+}
+
+func interfaceForRemote(remote string) (net.IP, error) {
+ conn, err := net.Dial("udp", remote)
+ if err != nil {
+ return nil, err
+ }
+
+ localAddr := conn.LocalAddr().(*net.UDPAddr)
+ if err := conn.Close(); err != nil {
+ return nil, err
+ }
+
+ return localAddr.IP, nil
+}
+
+func (c *Conn) sendQuestion(name string) {
+ packedName, err := dnsmessage.NewName(name)
+ if err != nil {
+ c.log.Warnf("Failed to construct mDNS packet %v", err)
+ return
+ }
+
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{},
+ Questions: []dnsmessage.Question{
+ {
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Name: packedName,
+ },
+ },
+ }
+
+ rawQuery, err := msg.Pack()
+ if err != nil {
+ c.log.Warnf("Failed to construct mDNS packet %v", err)
+ return
+ }
+
+ if _, err := c.socket.WriteTo(rawQuery, nil, c.dstAddr); err != nil {
+ c.log.Warnf("Failed to send mDNS packet %v", err)
+ return
+ }
+}
+
+func (c *Conn) sendAnswer(name string, dst net.IP) {
+ packedName, err := dnsmessage.NewName(name)
+ if err != nil {
+ c.log.Warnf("Failed to construct mDNS packet %v", err)
+ return
+ }
+
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ Response: true,
+ Authoritative: true,
+ },
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Name: packedName,
+ TTL: responseTTL,
+ },
+ Body: &dnsmessage.AResource{
+ A: ipToBytes(dst),
+ },
+ },
+ },
+ }
+
+ rawAnswer, err := msg.Pack()
+ if err != nil {
+ c.log.Warnf("Failed to construct mDNS packet %v", err)
+ return
+ }
+
+ if _, err := c.socket.WriteTo(rawAnswer, nil, c.dstAddr); err != nil {
+ c.log.Warnf("Failed to send mDNS packet %v", err)
+ return
+ }
+}
+
+func (c *Conn) start() {
+ defer func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ close(c.closed)
+ }()
+
+ b := make([]byte, inboundBufferSize)
+ p := dnsmessage.Parser{}
+
+ for {
+ n, _, src, err := c.socket.ReadFrom(b)
+ if err != nil {
+ return
+ }
+
+ func() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if _, err := p.Start(b[:n]); err != nil {
+ c.log.Warnf("Failed to parse mDNS packet %v", err)
+ return
+ }
+
+ for i := 0; i <= maxMessageRecords; i++ {
+ q, err := p.Question()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ } else if err != nil {
+ c.log.Warnf("Failed to parse mDNS packet %v", err)
+ return
+ }
+
+ for _, localName := range c.localNames {
+ if localName == q.Name.String() {
+
+ localAddress, err := interfaceForRemote(src.String())
+ if err != nil {
+ c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
+ continue
+ }
+
+ c.sendAnswer(q.Name.String(), localAddress)
+ }
+ }
+ }
+
+ for i := 0; i <= maxMessageRecords; i++ {
+ a, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return
+ }
+ if err != nil {
+ c.log.Warnf("Failed to parse mDNS packet %v", err)
+ return
+ }
+
+ if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA {
+ continue
+ }
+
+ for i := len(c.queries) - 1; i >= 0; i-- {
+ if c.queries[i].nameWithSuffix == a.Name.String() {
+ c.queries[i].queryResultChan <- queryResult{a, src}
+ c.queries = append(c.queries[:i], c.queries[i+1:]...)
+ }
+ }
+ }
+ }()
+ }
+}