summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/stun/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/stun/client.go')
-rw-r--r--vendor/github.com/pion/stun/client.go631
1 files changed, 631 insertions, 0 deletions
diff --git a/vendor/github.com/pion/stun/client.go b/vendor/github.com/pion/stun/client.go
new file mode 100644
index 0000000..62a0b6e
--- /dev/null
+++ b/vendor/github.com/pion/stun/client.go
@@ -0,0 +1,631 @@
+package stun
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Dial connects to the address on the named network and then
+// initializes Client on that connection, returning error if any.
+func Dial(network, address string) (*Client, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn)
+}
+
+// ErrNoConnection means that ClientOptions.Connection is nil.
+var ErrNoConnection = errors.New("no connection provided")
+
+// ClientOption sets some client option.
+type ClientOption func(c *Client)
+
+// WithHandler sets client handler which is called if Agent emits the Event
+// with TransactionID that is not currently registered by Client.
+// Useful for handling Data indications from TURN server.
+func WithHandler(h Handler) ClientOption {
+ return func(c *Client) {
+ c.handler = h
+ }
+}
+
+// WithRTO sets client RTO as defined in STUN RFC.
+func WithRTO(rto time.Duration) ClientOption {
+ return func(c *Client) {
+ c.rto = int64(rto)
+ }
+}
+
+// WithClock sets Clock of client, the source of current time.
+// Also clock is passed to default collector if set.
+func WithClock(clock Clock) ClientOption {
+ return func(c *Client) {
+ c.clock = clock
+ }
+}
+
+// WithTimeoutRate sets RTO timer minimum resolution.
+func WithTimeoutRate(d time.Duration) ClientOption {
+ return func(c *Client) {
+ c.rtoRate = d
+ }
+}
+
+// WithAgent sets client STUN agent.
+//
+// Defaults to agent implementation in current package,
+// see agent.go.
+func WithAgent(a ClientAgent) ClientOption {
+ return func(c *Client) {
+ c.a = a
+ }
+}
+
+// WithCollector rests client timeout collector, the implementation
+// of ticker which calls function on each tick.
+func WithCollector(coll Collector) ClientOption {
+ return func(c *Client) {
+ c.collector = coll
+ }
+}
+
+// WithNoConnClose prevents client from closing underlying connection when
+// the Close() method is called.
+var WithNoConnClose ClientOption = func(c *Client) {
+ c.closeConn = false
+}
+
+// WithNoRetransmit disables retransmissions and sets RTO to
+// defaultMaxAttempts * defaultRTO which will be effectively time out
+// if not set.
+//
+// Useful for TCP connections where transport handles RTO.
+func WithNoRetransmit(c *Client) {
+ c.maxAttempts = 0
+ if c.rto == 0 {
+ c.rto = defaultMaxAttempts * int64(defaultRTO)
+ }
+}
+
+const (
+ defaultTimeoutRate = time.Millisecond * 5
+ defaultRTO = time.Millisecond * 300
+ defaultMaxAttempts = 7
+)
+
+// NewClient initializes new Client from provided options,
+// starting internal goroutines and using default options fields
+// if necessary. Call Close method after using Client to close conn and
+// release resources.
+//
+// The conn will be closed on Close call. Use WithNoConnClose option to
+// prevent that.
+//
+// Note that user should handle the protocol multiplexing, client does not
+// provide any API for it, so if you need to read application data, wrap the
+// connection with your (de-)multiplexer and pass the wrapper as conn.
+func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
+ c := &Client{
+ close: make(chan struct{}),
+ c: conn,
+ clock: systemClock,
+ rto: int64(defaultRTO),
+ rtoRate: defaultTimeoutRate,
+ t: make(map[transactionID]*clientTransaction, 100),
+ maxAttempts: defaultMaxAttempts,
+ closeConn: true,
+ }
+ for _, o := range options {
+ o(c)
+ }
+ if c.c == nil {
+ return nil, ErrNoConnection
+ }
+ if c.a == nil {
+ c.a = NewAgent(nil)
+ }
+ if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
+ return nil, err
+ }
+ if c.collector == nil {
+ c.collector = &tickerCollector{
+ close: make(chan struct{}),
+ clock: c.clock,
+ }
+ }
+ if err := c.collector.Start(c.rtoRate, func(t time.Time) {
+ closedOrPanic(c.a.Collect(t))
+ }); err != nil {
+ return nil, err
+ }
+ c.wg.Add(1)
+ go c.readUntilClosed()
+ runtime.SetFinalizer(c, clientFinalizer)
+ return c, nil
+}
+
+func clientFinalizer(c *Client) {
+ if c == nil {
+ return
+ }
+ err := c.Close()
+ if err == ErrClientClosed {
+ return
+ }
+ if err == nil {
+ log.Println("client: called finalizer on non-closed client") // nolint
+ return
+ }
+ log.Println("client: called finalizer on non-closed client:", err) // nolint
+}
+
+// Connection wraps Reader, Writer and Closer interfaces.
+type Connection interface {
+ io.Reader
+ io.Writer
+ io.Closer
+}
+
+// ClientAgent is Agent implementation that is used by Client to
+// process transactions.
+type ClientAgent interface {
+ Process(*Message) error
+ Close() error
+ Start(id [TransactionIDSize]byte, deadline time.Time) error
+ Stop(id [TransactionIDSize]byte) error
+ Collect(time.Time) error
+ SetHandler(h Handler) error
+}
+
+// Client simulates "connection" to STUN server.
+type Client struct {
+ rto int64 // time.Duration
+ a ClientAgent
+ c Connection
+ close chan struct{}
+ rtoRate time.Duration
+ maxAttempts int32
+ closed bool
+ closeConn bool // should call c.Close() while closing
+ wg sync.WaitGroup
+ clock Clock
+ handler Handler
+ collector Collector
+ t map[transactionID]*clientTransaction
+
+ // mux guards closed and t
+ mux sync.RWMutex
+}
+
+// clientTransaction represents transaction in progress.
+// If transaction is succeed or failed, f will be called
+// provided by event.
+// Concurrent access is invalid.
+type clientTransaction struct {
+ id transactionID
+ attempt int32
+ calls int32
+ h Handler
+ start time.Time
+ rto time.Duration
+ raw []byte
+}
+
+func (t *clientTransaction) handle(e Event) {
+ if atomic.AddInt32(&t.calls, 1) == 1 {
+ t.h(e)
+ }
+}
+
+var clientTransactionPool = &sync.Pool{
+ New: func() interface{} {
+ return &clientTransaction{
+ raw: make([]byte, 1500),
+ }
+ },
+}
+
+func acquireClientTransaction() *clientTransaction {
+ return clientTransactionPool.Get().(*clientTransaction)
+}
+
+func putClientTransaction(t *clientTransaction) {
+ t.raw = t.raw[:0]
+ t.start = time.Time{}
+ t.attempt = 0
+ t.id = transactionID{}
+ clientTransactionPool.Put(t)
+}
+
+func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
+ return now.Add(time.Duration(t.attempt+1) * t.rto)
+}
+
+// start registers transaction.
+//
+// Could return ErrClientClosed, ErrTransactionExists.
+func (c *Client) start(t *clientTransaction) error {
+ c.mux.Lock()
+ defer c.mux.Unlock()
+ if c.closed {
+ return ErrClientClosed
+ }
+ _, exists := c.t[t.id]
+ if exists {
+ return ErrTransactionExists
+ }
+ c.t[t.id] = t
+ return nil
+}
+
+// Clock abstracts the source of current time.
+type Clock interface {
+ Now() time.Time
+}
+
+type systemClockService struct{}
+
+func (systemClockService) Now() time.Time { return time.Now() }
+
+var systemClock = systemClockService{}
+
+// SetRTO sets current RTO value.
+func (c *Client) SetRTO(rto time.Duration) {
+ atomic.StoreInt64(&c.rto, int64(rto))
+}
+
+// StopErr occurs when Client fails to stop transaction while
+// processing error.
+type StopErr struct {
+ Err error // value returned by Stop()
+ Cause error // error that caused Stop() call
+}
+
+func (e StopErr) Error() string {
+ return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err))
+}
+
+// CloseErr indicates client close failure.
+type CloseErr struct {
+ AgentErr error
+ ConnectionErr error
+}
+
+func sprintErr(err error) string {
+ if err == nil {
+ return "<nil>"
+ }
+ return err.Error()
+}
+
+func (c CloseErr) Error() string {
+ return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr))
+}
+
+func (c *Client) readUntilClosed() {
+ defer c.wg.Done()
+ m := new(Message)
+ m.Raw = make([]byte, 1024)
+ for {
+ select {
+ case <-c.close:
+ return
+ default:
+ }
+ _, err := m.ReadFrom(c.c)
+ if err == nil {
+ if pErr := c.a.Process(m); pErr == ErrAgentClosed {
+ return
+ }
+ }
+ }
+}
+
+func closedOrPanic(err error) {
+ if err == nil || err == ErrAgentClosed {
+ return
+ }
+ panic(err) // nolint
+}
+
+type tickerCollector struct {
+ close chan struct{}
+ wg sync.WaitGroup
+ clock Clock
+}
+
+// Collector calls function f with constant rate.
+//
+// The simple Collector is ticker which calls function on each tick.
+type Collector interface {
+ Start(rate time.Duration, f func(now time.Time)) error
+ Close() error
+}
+
+func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
+ t := time.NewTicker(rate)
+ a.wg.Add(1)
+ go func() {
+ defer a.wg.Done()
+ for {
+ select {
+ case <-a.close:
+ t.Stop()
+ return
+ case <-t.C:
+ f(a.clock.Now())
+ }
+ }
+ }()
+ return nil
+}
+
+func (a *tickerCollector) Close() error {
+ close(a.close)
+ a.wg.Wait()
+ return nil
+}
+
+// ErrClientClosed indicates that client is closed.
+var ErrClientClosed = errors.New("client is closed")
+
+// Close stops internal connection and agent, returning CloseErr on error.
+func (c *Client) Close() error {
+ if err := c.checkInit(); err != nil {
+ return err
+ }
+ c.mux.Lock()
+ if c.closed {
+ c.mux.Unlock()
+ return ErrClientClosed
+ }
+ c.closed = true
+ c.mux.Unlock()
+ if closeErr := c.collector.Close(); closeErr != nil {
+ return closeErr
+ }
+ var connErr error
+ agentErr := c.a.Close()
+ if c.closeConn {
+ connErr = c.c.Close()
+ }
+ close(c.close)
+ c.wg.Wait()
+ if agentErr == nil && connErr == nil {
+ return nil
+ }
+ return CloseErr{
+ AgentErr: agentErr,
+ ConnectionErr: connErr,
+ }
+}
+
+// Indicate sends indication m to server. Shorthand to Start call
+// with zero deadline and callback.
+func (c *Client) Indicate(m *Message) error {
+ return c.Start(m, nil)
+}
+
+// callbackWaitHandler blocks on wait() call until callback is called.
+type callbackWaitHandler struct {
+ handler Handler
+ callback func(event Event)
+ cond *sync.Cond
+ processed bool
+}
+
+func (s *callbackWaitHandler) HandleEvent(e Event) {
+ s.cond.L.Lock()
+ if s.callback == nil {
+ panic("s.callback is nil") // nolint
+ }
+ s.callback(e)
+ s.processed = true
+ s.cond.Broadcast()
+ s.cond.L.Unlock()
+}
+
+func (s *callbackWaitHandler) wait() {
+ s.cond.L.Lock()
+ for !s.processed {
+ s.cond.Wait()
+ }
+ s.processed = false
+ s.callback = nil
+ s.cond.L.Unlock()
+}
+
+func (s *callbackWaitHandler) setCallback(f func(event Event)) {
+ if f == nil {
+ panic("f is nil") // nolint
+ }
+ s.cond.L.Lock()
+ s.callback = f
+ if s.handler == nil {
+ s.handler = s.HandleEvent
+ }
+ s.cond.L.Unlock()
+}
+
+var callbackWaitHandlerPool = sync.Pool{
+ New: func() interface{} {
+ return &callbackWaitHandler{
+ cond: sync.NewCond(new(sync.Mutex)),
+ }
+ },
+}
+
+// ErrClientNotInitialized means that client connection or agent is nil.
+var ErrClientNotInitialized = errors.New("client not initialized")
+
+func (c *Client) checkInit() error {
+ if c == nil || c.c == nil || c.a == nil || c.close == nil {
+ return ErrClientNotInitialized
+ }
+ return nil
+}
+
+// Do is Start wrapper that waits until callback is called. If no callback
+// provided, Indicate is called instead.
+//
+// Do has cpu overhead due to blocking, see BenchmarkClient_Do.
+// Use Start method for less overhead.
+func (c *Client) Do(m *Message, f func(Event)) error {
+ if err := c.checkInit(); err != nil {
+ return err
+ }
+ if f == nil {
+ return c.Indicate(m)
+ }
+ h := callbackWaitHandlerPool.Get().(*callbackWaitHandler)
+ h.setCallback(f)
+ defer func() {
+ callbackWaitHandlerPool.Put(h)
+ }()
+ if err := c.Start(m, h.handler); err != nil {
+ return err
+ }
+ h.wait()
+ return nil
+}
+
+func (c *Client) delete(id transactionID) {
+ c.mux.Lock()
+ if c.t != nil {
+ delete(c.t, id)
+ }
+ c.mux.Unlock()
+}
+
+type buffer struct {
+ buf []byte
+}
+
+var bufferPool = &sync.Pool{
+ New: func() interface{} {
+ return &buffer{buf: make([]byte, 2048)}
+ },
+}
+
+func (c *Client) handleAgentCallback(e Event) {
+ c.mux.Lock()
+ if c.closed {
+ c.mux.Unlock()
+ return
+ }
+ t, found := c.t[e.TransactionID]
+ if found {
+ delete(c.t, t.id)
+ }
+ c.mux.Unlock()
+ if !found {
+ if c.handler != nil && e.Error != ErrTransactionStopped {
+ c.handler(e)
+ }
+ // Ignoring.
+ return
+ }
+ if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
+ // Transaction completed.
+ t.handle(e)
+ putClientTransaction(t)
+ return
+ }
+ // Doing re-transmission.
+ t.attempt++
+ b := bufferPool.Get().(*buffer)
+ b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
+ defer bufferPool.Put(b)
+ var (
+ now = c.clock.Now()
+ timeOut = t.nextTimeout(now)
+ id = t.id
+ )
+ // Starting client transaction.
+ if startErr := c.start(t); startErr != nil {
+ c.delete(id)
+ e.Error = startErr
+ t.handle(e)
+ putClientTransaction(t)
+ return
+ }
+ // Starting agent transaction.
+ if startErr := c.a.Start(id, timeOut); startErr != nil {
+ c.delete(id)
+ e.Error = startErr
+ t.handle(e)
+ putClientTransaction(t)
+ return
+ }
+ // Writing message to connection again.
+ _, writeErr := c.c.Write(b.buf)
+ if writeErr != nil {
+ c.delete(id)
+ e.Error = writeErr
+ // Stopping agent transaction instead of waiting until it's deadline.
+ // This will call handleAgentCallback with "ErrTransactionStopped" error
+ // which will be ignored.
+ if stopErr := c.a.Stop(id); stopErr != nil {
+ // Failed to stop agent transaction. Wrapping the error in StopError.
+ e.Error = StopErr{
+ Err: stopErr,
+ Cause: writeErr,
+ }
+ }
+ t.handle(e)
+ putClientTransaction(t)
+ return
+ }
+}
+
+// Start starts transaction (if h set) and writes message to server, handler
+// is called asynchronously.
+func (c *Client) Start(m *Message, h Handler) error {
+ if err := c.checkInit(); err != nil {
+ return err
+ }
+ c.mux.RLock()
+ closed := c.closed
+ c.mux.RUnlock()
+ if closed {
+ return ErrClientClosed
+ }
+ if h != nil {
+ // Starting transaction only if h is set. Useful for indications.
+ t := acquireClientTransaction()
+ t.id = m.TransactionID
+ t.start = c.clock.Now()
+ t.h = h
+ t.rto = time.Duration(atomic.LoadInt64(&c.rto))
+ t.attempt = 0
+ t.raw = append(t.raw[:0], m.Raw...)
+ t.calls = 0
+ d := t.nextTimeout(t.start)
+ if err := c.start(t); err != nil {
+ return err
+ }
+ if err := c.a.Start(m.TransactionID, d); err != nil {
+ return err
+ }
+ }
+ _, err := m.WriteTo(c.c)
+ if err != nil && h != nil {
+ c.delete(m.TransactionID)
+ // Stopping transaction instead of waiting until deadline.
+ if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
+ return StopErr{
+ Err: stopErr,
+ Cause: err,
+ }
+ }
+ }
+ return err
+}