diff options
Diffstat (limited to 'vendor/github.com/pion/stun/client.go')
-rw-r--r-- | vendor/github.com/pion/stun/client.go | 631 |
1 files changed, 631 insertions, 0 deletions
diff --git a/vendor/github.com/pion/stun/client.go b/vendor/github.com/pion/stun/client.go new file mode 100644 index 0000000..62a0b6e --- /dev/null +++ b/vendor/github.com/pion/stun/client.go @@ -0,0 +1,631 @@ +package stun + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "runtime" + "sync" + "sync/atomic" + "time" +) + +// Dial connects to the address on the named network and then +// initializes Client on that connection, returning error if any. +func Dial(network, address string) (*Client, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return NewClient(conn) +} + +// ErrNoConnection means that ClientOptions.Connection is nil. +var ErrNoConnection = errors.New("no connection provided") + +// ClientOption sets some client option. +type ClientOption func(c *Client) + +// WithHandler sets client handler which is called if Agent emits the Event +// with TransactionID that is not currently registered by Client. +// Useful for handling Data indications from TURN server. +func WithHandler(h Handler) ClientOption { + return func(c *Client) { + c.handler = h + } +} + +// WithRTO sets client RTO as defined in STUN RFC. +func WithRTO(rto time.Duration) ClientOption { + return func(c *Client) { + c.rto = int64(rto) + } +} + +// WithClock sets Clock of client, the source of current time. +// Also clock is passed to default collector if set. +func WithClock(clock Clock) ClientOption { + return func(c *Client) { + c.clock = clock + } +} + +// WithTimeoutRate sets RTO timer minimum resolution. +func WithTimeoutRate(d time.Duration) ClientOption { + return func(c *Client) { + c.rtoRate = d + } +} + +// WithAgent sets client STUN agent. +// +// Defaults to agent implementation in current package, +// see agent.go. +func WithAgent(a ClientAgent) ClientOption { + return func(c *Client) { + c.a = a + } +} + +// WithCollector rests client timeout collector, the implementation +// of ticker which calls function on each tick. +func WithCollector(coll Collector) ClientOption { + return func(c *Client) { + c.collector = coll + } +} + +// WithNoConnClose prevents client from closing underlying connection when +// the Close() method is called. +var WithNoConnClose ClientOption = func(c *Client) { + c.closeConn = false +} + +// WithNoRetransmit disables retransmissions and sets RTO to +// defaultMaxAttempts * defaultRTO which will be effectively time out +// if not set. +// +// Useful for TCP connections where transport handles RTO. +func WithNoRetransmit(c *Client) { + c.maxAttempts = 0 + if c.rto == 0 { + c.rto = defaultMaxAttempts * int64(defaultRTO) + } +} + +const ( + defaultTimeoutRate = time.Millisecond * 5 + defaultRTO = time.Millisecond * 300 + defaultMaxAttempts = 7 +) + +// NewClient initializes new Client from provided options, +// starting internal goroutines and using default options fields +// if necessary. Call Close method after using Client to close conn and +// release resources. +// +// The conn will be closed on Close call. Use WithNoConnClose option to +// prevent that. +// +// Note that user should handle the protocol multiplexing, client does not +// provide any API for it, so if you need to read application data, wrap the +// connection with your (de-)multiplexer and pass the wrapper as conn. +func NewClient(conn Connection, options ...ClientOption) (*Client, error) { + c := &Client{ + close: make(chan struct{}), + c: conn, + clock: systemClock, + rto: int64(defaultRTO), + rtoRate: defaultTimeoutRate, + t: make(map[transactionID]*clientTransaction, 100), + maxAttempts: defaultMaxAttempts, + closeConn: true, + } + for _, o := range options { + o(c) + } + if c.c == nil { + return nil, ErrNoConnection + } + if c.a == nil { + c.a = NewAgent(nil) + } + if err := c.a.SetHandler(c.handleAgentCallback); err != nil { + return nil, err + } + if c.collector == nil { + c.collector = &tickerCollector{ + close: make(chan struct{}), + clock: c.clock, + } + } + if err := c.collector.Start(c.rtoRate, func(t time.Time) { + closedOrPanic(c.a.Collect(t)) + }); err != nil { + return nil, err + } + c.wg.Add(1) + go c.readUntilClosed() + runtime.SetFinalizer(c, clientFinalizer) + return c, nil +} + +func clientFinalizer(c *Client) { + if c == nil { + return + } + err := c.Close() + if err == ErrClientClosed { + return + } + if err == nil { + log.Println("client: called finalizer on non-closed client") // nolint + return + } + log.Println("client: called finalizer on non-closed client:", err) // nolint +} + +// Connection wraps Reader, Writer and Closer interfaces. +type Connection interface { + io.Reader + io.Writer + io.Closer +} + +// ClientAgent is Agent implementation that is used by Client to +// process transactions. +type ClientAgent interface { + Process(*Message) error + Close() error + Start(id [TransactionIDSize]byte, deadline time.Time) error + Stop(id [TransactionIDSize]byte) error + Collect(time.Time) error + SetHandler(h Handler) error +} + +// Client simulates "connection" to STUN server. +type Client struct { + rto int64 // time.Duration + a ClientAgent + c Connection + close chan struct{} + rtoRate time.Duration + maxAttempts int32 + closed bool + closeConn bool // should call c.Close() while closing + wg sync.WaitGroup + clock Clock + handler Handler + collector Collector + t map[transactionID]*clientTransaction + + // mux guards closed and t + mux sync.RWMutex +} + +// clientTransaction represents transaction in progress. +// If transaction is succeed or failed, f will be called +// provided by event. +// Concurrent access is invalid. +type clientTransaction struct { + id transactionID + attempt int32 + calls int32 + h Handler + start time.Time + rto time.Duration + raw []byte +} + +func (t *clientTransaction) handle(e Event) { + if atomic.AddInt32(&t.calls, 1) == 1 { + t.h(e) + } +} + +var clientTransactionPool = &sync.Pool{ + New: func() interface{} { + return &clientTransaction{ + raw: make([]byte, 1500), + } + }, +} + +func acquireClientTransaction() *clientTransaction { + return clientTransactionPool.Get().(*clientTransaction) +} + +func putClientTransaction(t *clientTransaction) { + t.raw = t.raw[:0] + t.start = time.Time{} + t.attempt = 0 + t.id = transactionID{} + clientTransactionPool.Put(t) +} + +func (t *clientTransaction) nextTimeout(now time.Time) time.Time { + return now.Add(time.Duration(t.attempt+1) * t.rto) +} + +// start registers transaction. +// +// Could return ErrClientClosed, ErrTransactionExists. +func (c *Client) start(t *clientTransaction) error { + c.mux.Lock() + defer c.mux.Unlock() + if c.closed { + return ErrClientClosed + } + _, exists := c.t[t.id] + if exists { + return ErrTransactionExists + } + c.t[t.id] = t + return nil +} + +// Clock abstracts the source of current time. +type Clock interface { + Now() time.Time +} + +type systemClockService struct{} + +func (systemClockService) Now() time.Time { return time.Now() } + +var systemClock = systemClockService{} + +// SetRTO sets current RTO value. +func (c *Client) SetRTO(rto time.Duration) { + atomic.StoreInt64(&c.rto, int64(rto)) +} + +// StopErr occurs when Client fails to stop transaction while +// processing error. +type StopErr struct { + Err error // value returned by Stop() + Cause error // error that caused Stop() call +} + +func (e StopErr) Error() string { + return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err)) +} + +// CloseErr indicates client close failure. +type CloseErr struct { + AgentErr error + ConnectionErr error +} + +func sprintErr(err error) string { + if err == nil { + return "<nil>" + } + return err.Error() +} + +func (c CloseErr) Error() string { + return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr)) +} + +func (c *Client) readUntilClosed() { + defer c.wg.Done() + m := new(Message) + m.Raw = make([]byte, 1024) + for { + select { + case <-c.close: + return + default: + } + _, err := m.ReadFrom(c.c) + if err == nil { + if pErr := c.a.Process(m); pErr == ErrAgentClosed { + return + } + } + } +} + +func closedOrPanic(err error) { + if err == nil || err == ErrAgentClosed { + return + } + panic(err) // nolint +} + +type tickerCollector struct { + close chan struct{} + wg sync.WaitGroup + clock Clock +} + +// Collector calls function f with constant rate. +// +// The simple Collector is ticker which calls function on each tick. +type Collector interface { + Start(rate time.Duration, f func(now time.Time)) error + Close() error +} + +func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error { + t := time.NewTicker(rate) + a.wg.Add(1) + go func() { + defer a.wg.Done() + for { + select { + case <-a.close: + t.Stop() + return + case <-t.C: + f(a.clock.Now()) + } + } + }() + return nil +} + +func (a *tickerCollector) Close() error { + close(a.close) + a.wg.Wait() + return nil +} + +// ErrClientClosed indicates that client is closed. +var ErrClientClosed = errors.New("client is closed") + +// Close stops internal connection and agent, returning CloseErr on error. +func (c *Client) Close() error { + if err := c.checkInit(); err != nil { + return err + } + c.mux.Lock() + if c.closed { + c.mux.Unlock() + return ErrClientClosed + } + c.closed = true + c.mux.Unlock() + if closeErr := c.collector.Close(); closeErr != nil { + return closeErr + } + var connErr error + agentErr := c.a.Close() + if c.closeConn { + connErr = c.c.Close() + } + close(c.close) + c.wg.Wait() + if agentErr == nil && connErr == nil { + return nil + } + return CloseErr{ + AgentErr: agentErr, + ConnectionErr: connErr, + } +} + +// Indicate sends indication m to server. Shorthand to Start call +// with zero deadline and callback. +func (c *Client) Indicate(m *Message) error { + return c.Start(m, nil) +} + +// callbackWaitHandler blocks on wait() call until callback is called. +type callbackWaitHandler struct { + handler Handler + callback func(event Event) + cond *sync.Cond + processed bool +} + +func (s *callbackWaitHandler) HandleEvent(e Event) { + s.cond.L.Lock() + if s.callback == nil { + panic("s.callback is nil") // nolint + } + s.callback(e) + s.processed = true + s.cond.Broadcast() + s.cond.L.Unlock() +} + +func (s *callbackWaitHandler) wait() { + s.cond.L.Lock() + for !s.processed { + s.cond.Wait() + } + s.processed = false + s.callback = nil + s.cond.L.Unlock() +} + +func (s *callbackWaitHandler) setCallback(f func(event Event)) { + if f == nil { + panic("f is nil") // nolint + } + s.cond.L.Lock() + s.callback = f + if s.handler == nil { + s.handler = s.HandleEvent + } + s.cond.L.Unlock() +} + +var callbackWaitHandlerPool = sync.Pool{ + New: func() interface{} { + return &callbackWaitHandler{ + cond: sync.NewCond(new(sync.Mutex)), + } + }, +} + +// ErrClientNotInitialized means that client connection or agent is nil. +var ErrClientNotInitialized = errors.New("client not initialized") + +func (c *Client) checkInit() error { + if c == nil || c.c == nil || c.a == nil || c.close == nil { + return ErrClientNotInitialized + } + return nil +} + +// Do is Start wrapper that waits until callback is called. If no callback +// provided, Indicate is called instead. +// +// Do has cpu overhead due to blocking, see BenchmarkClient_Do. +// Use Start method for less overhead. +func (c *Client) Do(m *Message, f func(Event)) error { + if err := c.checkInit(); err != nil { + return err + } + if f == nil { + return c.Indicate(m) + } + h := callbackWaitHandlerPool.Get().(*callbackWaitHandler) + h.setCallback(f) + defer func() { + callbackWaitHandlerPool.Put(h) + }() + if err := c.Start(m, h.handler); err != nil { + return err + } + h.wait() + return nil +} + +func (c *Client) delete(id transactionID) { + c.mux.Lock() + if c.t != nil { + delete(c.t, id) + } + c.mux.Unlock() +} + +type buffer struct { + buf []byte +} + +var bufferPool = &sync.Pool{ + New: func() interface{} { + return &buffer{buf: make([]byte, 2048)} + }, +} + +func (c *Client) handleAgentCallback(e Event) { + c.mux.Lock() + if c.closed { + c.mux.Unlock() + return + } + t, found := c.t[e.TransactionID] + if found { + delete(c.t, t.id) + } + c.mux.Unlock() + if !found { + if c.handler != nil && e.Error != ErrTransactionStopped { + c.handler(e) + } + // Ignoring. + return + } + if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil { + // Transaction completed. + t.handle(e) + putClientTransaction(t) + return + } + // Doing re-transmission. + t.attempt++ + b := bufferPool.Get().(*buffer) + b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)] + defer bufferPool.Put(b) + var ( + now = c.clock.Now() + timeOut = t.nextTimeout(now) + id = t.id + ) + // Starting client transaction. + if startErr := c.start(t); startErr != nil { + c.delete(id) + e.Error = startErr + t.handle(e) + putClientTransaction(t) + return + } + // Starting agent transaction. + if startErr := c.a.Start(id, timeOut); startErr != nil { + c.delete(id) + e.Error = startErr + t.handle(e) + putClientTransaction(t) + return + } + // Writing message to connection again. + _, writeErr := c.c.Write(b.buf) + if writeErr != nil { + c.delete(id) + e.Error = writeErr + // Stopping agent transaction instead of waiting until it's deadline. + // This will call handleAgentCallback with "ErrTransactionStopped" error + // which will be ignored. + if stopErr := c.a.Stop(id); stopErr != nil { + // Failed to stop agent transaction. Wrapping the error in StopError. + e.Error = StopErr{ + Err: stopErr, + Cause: writeErr, + } + } + t.handle(e) + putClientTransaction(t) + return + } +} + +// Start starts transaction (if h set) and writes message to server, handler +// is called asynchronously. +func (c *Client) Start(m *Message, h Handler) error { + if err := c.checkInit(); err != nil { + return err + } + c.mux.RLock() + closed := c.closed + c.mux.RUnlock() + if closed { + return ErrClientClosed + } + if h != nil { + // Starting transaction only if h is set. Useful for indications. + t := acquireClientTransaction() + t.id = m.TransactionID + t.start = c.clock.Now() + t.h = h + t.rto = time.Duration(atomic.LoadInt64(&c.rto)) + t.attempt = 0 + t.raw = append(t.raw[:0], m.Raw...) + t.calls = 0 + d := t.nextTimeout(t.start) + if err := c.start(t); err != nil { + return err + } + if err := c.a.Start(m.TransactionID, d); err != nil { + return err + } + } + _, err := m.WriteTo(c.c) + if err != nil && h != nil { + c.delete(m.TransactionID) + // Stopping transaction instead of waiting until deadline. + if stopErr := c.a.Stop(m.TransactionID); stopErr != nil { + return StopErr{ + Err: stopErr, + Cause: err, + } + } + } + return err +} |