summaryrefslogtreecommitdiff
path: root/vendor/github.com/pion/stun/message.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/pion/stun/message.go')
-rw-r--r--vendor/github.com/pion/stun/message.go588
1 files changed, 588 insertions, 0 deletions
diff --git a/vendor/github.com/pion/stun/message.go b/vendor/github.com/pion/stun/message.go
new file mode 100644
index 0000000..3819235
--- /dev/null
+++ b/vendor/github.com/pion/stun/message.go
@@ -0,0 +1,588 @@
+package stun
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+)
+
+const (
+ // magicCookie is fixed value that aids in distinguishing STUN packets
+ // from packets of other protocols when STUN is multiplexed with those
+ // other protocols on the same Port.
+ //
+ // The magic cookie field MUST contain the fixed value 0x2112A442 in
+ // network byte order.
+ //
+ // Defined in "STUN Message Structure", section 6.
+ magicCookie = 0x2112A442
+ attributeHeaderSize = 4
+ messageHeaderSize = 20
+
+ // TransactionIDSize is length of transaction id array (in bytes).
+ TransactionIDSize = 12 // 96 bit
+)
+
+// NewTransactionID returns new random transaction ID using crypto/rand
+// as source.
+func NewTransactionID() (b [TransactionIDSize]byte) {
+ readFullOrPanic(rand.Reader, b[:])
+ return b
+}
+
+// IsMessage returns true if b looks like STUN message.
+// Useful for multiplexing. IsMessage does not guarantee
+// that decoding will be successful.
+func IsMessage(b []byte) bool {
+ return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
+}
+
+// New returns *Message with pre-allocated Raw.
+func New() *Message {
+ const defaultRawCapacity = 120
+ return &Message{
+ Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
+ }
+}
+
+// ErrDecodeToNil occurs on Decode(data, nil) call.
+var ErrDecodeToNil = errors.New("attempt to decode to nil message")
+
+// Decode decodes Message from data to m, returning error if any.
+func Decode(data []byte, m *Message) error {
+ if m == nil {
+ return ErrDecodeToNil
+ }
+ m.Raw = append(m.Raw[:0], data...)
+ return m.Decode()
+}
+
+// Message represents a single STUN packet. It uses aggressive internal
+// buffering to enable zero-allocation encoding and decoding,
+// so there are some usage constraints:
+//
+// Message, its fields, results of m.Get or any attribute a.GetFrom
+// are valid only until Message.Raw is not modified.
+type Message struct {
+ Type MessageType
+ Length uint32 // len(Raw) not including header
+ TransactionID [TransactionIDSize]byte
+ Attributes Attributes
+ Raw []byte
+}
+
+// AddTo sets b.TransactionID to m.TransactionID.
+//
+// Implements Setter to aid in crafting responses.
+func (m *Message) AddTo(b *Message) error {
+ b.TransactionID = m.TransactionID
+ b.WriteTransactionID()
+ return nil
+}
+
+// NewTransactionID sets m.TransactionID to random value from crypto/rand
+// and returns error if any.
+func (m *Message) NewTransactionID() error {
+ _, err := io.ReadFull(rand.Reader, m.TransactionID[:])
+ if err == nil {
+ m.WriteTransactionID()
+ }
+ return err
+}
+
+func (m *Message) String() string {
+ tID := base64.StdEncoding.EncodeToString(m.TransactionID[:])
+ return fmt.Sprintf("%s l=%d attrs=%d id=%s", m.Type, m.Length, len(m.Attributes), tID)
+}
+
+// Reset resets Message, attributes and underlying buffer length.
+func (m *Message) Reset() {
+ m.Raw = m.Raw[:0]
+ m.Length = 0
+ m.Attributes = m.Attributes[:0]
+}
+
+// grow ensures that internal buffer has n length.
+func (m *Message) grow(n int) {
+ if len(m.Raw) >= n {
+ return
+ }
+ if cap(m.Raw) >= n {
+ m.Raw = m.Raw[:n]
+ return
+ }
+ m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
+}
+
+// Add appends new attribute to message. Not goroutine-safe.
+//
+// Value of attribute is copied to internal buffer so
+// it is safe to reuse v.
+func (m *Message) Add(t AttrType, v []byte) {
+ // Allocating buffer for TLV (type-length-value).
+ // T = t, L = len(v), V = v.
+ // m.Raw will look like:
+ // [0:20] <- message header
+ // [20:20+m.Length] <- existing message attributes
+ // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
+ // [first:last] <- same as previous
+ // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
+ // T L V
+ allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V)
+ first := messageHeaderSize + int(m.Length) // first byte number
+ last := first + allocSize // last byte number
+ m.grow(last) // growing cap(Raw) to fit TLV
+ m.Raw = m.Raw[:last] // now len(Raw) = last
+ m.Length += uint32(allocSize) // rendering length change
+
+ // Sub-slicing internal buffer to simplify encoding.
+ buf := m.Raw[first:last] // slice for TLV
+ value := buf[attributeHeaderSize:] // slice for V
+ attr := RawAttribute{
+ Type: t, // T
+ Length: uint16(len(v)), // L
+ Value: value, // V
+ }
+
+ // Encoding attribute TLV to allocated buffer.
+ bin.PutUint16(buf[0:2], attr.Type.Value()) // T
+ bin.PutUint16(buf[2:4], attr.Length) // L
+ copy(value, v) // V
+
+ // Checking that attribute value needs padding.
+ if attr.Length%padding != 0 {
+ // Performing padding.
+ bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
+ last += bytesToAdd
+ m.grow(last)
+ // setting all padding bytes to zero
+ // to prevent data leak from previous
+ // data in next bytesToAdd bytes
+ buf = m.Raw[last-bytesToAdd : last]
+ for i := range buf {
+ buf[i] = 0
+ }
+ m.Raw = m.Raw[:last] // increasing buffer length
+ m.Length += uint32(bytesToAdd) // rendering length change
+ }
+ m.Attributes = append(m.Attributes, attr)
+ m.WriteLength()
+}
+
+func attrSliceEqual(a, b Attributes) bool {
+ for _, attr := range a {
+ found := false
+ for _, attrB := range b {
+ if attrB.Type != attr.Type {
+ continue
+ }
+ if attrB.Equal(attr) {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return false
+ }
+ }
+ return true
+}
+
+func attrEqual(a, b Attributes) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
+ if len(a) != len(b) {
+ return false
+ }
+ if !attrSliceEqual(a, b) {
+ return false
+ }
+ if !attrSliceEqual(b, a) {
+ return false
+ }
+ return true
+}
+
+// Equal returns true if Message b equals to m.
+// Ignores m.Raw.
+func (m *Message) Equal(b *Message) bool {
+ if m == nil && b == nil {
+ return true
+ }
+ if m == nil || b == nil {
+ return false
+ }
+ if m.Type != b.Type {
+ return false
+ }
+ if m.TransactionID != b.TransactionID {
+ return false
+ }
+ if m.Length != b.Length {
+ return false
+ }
+ if !attrEqual(m.Attributes, b.Attributes) {
+ return false
+ }
+ return true
+}
+
+// WriteLength writes m.Length to m.Raw.
+func (m *Message) WriteLength() {
+ m.grow(4)
+ bin.PutUint16(m.Raw[2:4], uint16(m.Length))
+}
+
+// WriteHeader writes header to underlying buffer. Not goroutine-safe.
+func (m *Message) WriteHeader() {
+ m.grow(messageHeaderSize)
+ _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
+
+ m.WriteType()
+ m.WriteLength()
+ bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie
+ copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
+}
+
+// WriteTransactionID writes m.TransactionID to m.Raw.
+func (m *Message) WriteTransactionID() {
+ copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
+}
+
+// WriteAttributes encodes all m.Attributes to m.
+func (m *Message) WriteAttributes() {
+ attributes := m.Attributes
+ m.Attributes = attributes[:0]
+ for _, a := range attributes {
+ m.Add(a.Type, a.Value)
+ }
+ m.Attributes = attributes
+}
+
+// WriteType writes m.Type to m.Raw.
+func (m *Message) WriteType() {
+ m.grow(2)
+ bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
+}
+
+// SetType sets m.Type and writes it to m.Raw.
+func (m *Message) SetType(t MessageType) {
+ m.Type = t
+ m.WriteType()
+}
+
+// Encode re-encodes message into m.Raw.
+func (m *Message) Encode() {
+ m.Raw = m.Raw[:0]
+ m.WriteHeader()
+ m.Length = 0
+ m.WriteAttributes()
+}
+
+// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
+// call result.
+func (m *Message) WriteTo(w io.Writer) (int64, error) {
+ n, err := w.Write(m.Raw)
+ return int64(n), err
+}
+
+// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
+// Decodes it and return error if any. If m.Raw is too small, will return
+// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
+//
+// Can return *DecodeErr while decoding too.
+func (m *Message) ReadFrom(r io.Reader) (int64, error) {
+ tBuf := m.Raw[:cap(m.Raw)]
+ var (
+ n int
+ err error
+ )
+ if n, err = r.Read(tBuf); err != nil {
+ return int64(n), err
+ }
+ m.Raw = tBuf[:n]
+ return int64(n), m.Decode()
+}
+
+// ErrUnexpectedHeaderEOF means that there were not enough bytes in
+// m.Raw to read header.
+var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")
+
+// Decode decodes m.Raw into m.
+func (m *Message) Decode() error {
+ // decoding message header
+ buf := m.Raw
+ if len(buf) < messageHeaderSize {
+ return ErrUnexpectedHeaderEOF
+ }
+ var (
+ t = bin.Uint16(buf[0:2]) // first 2 bytes
+ size = int(bin.Uint16(buf[2:4])) // second 2 bytes
+ cookie = bin.Uint32(buf[4:8]) // last 4 bytes
+ fullSize = messageHeaderSize + size // len(m.Raw)
+ )
+ if cookie != magicCookie {
+ msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
+ return newDecodeErr("message", "cookie", msg)
+ }
+ if len(buf) < fullSize {
+ msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
+ return newAttrDecodeErr("message", msg)
+ }
+ // saving header data
+ m.Type.ReadValue(t)
+ m.Length = uint32(size)
+ copy(m.TransactionID[:], buf[8:messageHeaderSize])
+
+ m.Attributes = m.Attributes[:0]
+ var (
+ offset = 0
+ b = buf[messageHeaderSize:fullSize]
+ )
+ for offset < size {
+ // checking that we have enough bytes to read header
+ if len(b) < attributeHeaderSize {
+ msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
+ return newAttrDecodeErr("header", msg)
+ }
+ var (
+ a = RawAttribute{
+ Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
+ Length: bin.Uint16(b[2:4]), // second 2 bytes
+ }
+ aL = int(a.Length) // attribute length
+ aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
+ )
+ b = b[attributeHeaderSize:] // slicing again to simplify value read
+ offset += attributeHeaderSize
+ if len(b) < aBuffL { // checking size
+ msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
+ return newAttrDecodeErr("value", msg)
+ }
+ a.Value = b[:aL]
+ offset += aBuffL
+ b = b[aBuffL:]
+
+ m.Attributes = append(m.Attributes, a)
+ }
+ return nil
+}
+
+// Write decodes message and return error if any.
+//
+// Any error is unrecoverable, but message could be partially decoded.
+func (m *Message) Write(tBuf []byte) (int, error) {
+ m.Raw = append(m.Raw[:0], tBuf...)
+ return len(tBuf), m.Decode()
+}
+
+// CloneTo clones m to b securing any further m mutations.
+func (m *Message) CloneTo(b *Message) error {
+ // TODO(ar): implement low-level copy.
+ b.Raw = append(b.Raw[:0], m.Raw...)
+ return b.Decode()
+}
+
+// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
+type MessageClass byte
+
+// Possible values for message class in STUN Message Type.
+const (
+ ClassRequest MessageClass = 0x00 // 0b00
+ ClassIndication MessageClass = 0x01 // 0b01
+ ClassSuccessResponse MessageClass = 0x02 // 0b10
+ ClassErrorResponse MessageClass = 0x03 // 0b11
+)
+
+// Common STUN message types.
+var (
+ // Binding request message type.
+ BindingRequest = NewType(MethodBinding, ClassRequest)
+ // Binding success response message type
+ BindingSuccess = NewType(MethodBinding, ClassSuccessResponse)
+ // Binding error response message type.
+ BindingError = NewType(MethodBinding, ClassErrorResponse)
+)
+
+func (c MessageClass) String() string {
+ switch c {
+ case ClassRequest:
+ return "request"
+ case ClassIndication:
+ return "indication"
+ case ClassSuccessResponse:
+ return "success response"
+ case ClassErrorResponse:
+ return "error response"
+ default:
+ panic("unknown message class") // nolint: never happens unless wrongly casted
+ }
+}
+
+// Method is uint16 representation of 12-bit STUN method.
+type Method uint16
+
+// Possible methods for STUN Message.
+const (
+ MethodBinding Method = 0x001
+ MethodAllocate Method = 0x003
+ MethodRefresh Method = 0x004
+ MethodSend Method = 0x006
+ MethodData Method = 0x007
+ MethodCreatePermission Method = 0x008
+ MethodChannelBind Method = 0x009
+)
+
+// Methods from RFC 6062.
+const (
+ MethodConnect Method = 0x000a
+ MethodConnectionBind Method = 0x000b
+ MethodConnectionAttempt Method = 0x000c
+)
+
+var methodName = map[Method]string{
+ MethodBinding: "Binding",
+ MethodAllocate: "Allocate",
+ MethodRefresh: "Refresh",
+ MethodSend: "Send",
+ MethodData: "Data",
+ MethodCreatePermission: "CreatePermission",
+ MethodChannelBind: "ChannelBind",
+
+ // RFC 6062.
+ MethodConnect: "Connect",
+ MethodConnectionBind: "ConnectionBind",
+ MethodConnectionAttempt: "ConnectionAttempt",
+}
+
+func (m Method) String() string {
+ s, ok := methodName[m]
+ if !ok {
+ // Falling back to hex representation.
+ s = fmt.Sprintf("0x%x", uint16(m))
+ }
+ return s
+}
+
+// MessageType is STUN Message Type Field.
+type MessageType struct {
+ Method Method // e.g. binding
+ Class MessageClass // e.g. request
+}
+
+// AddTo sets m type to t.
+func (t MessageType) AddTo(m *Message) error {
+ m.SetType(t)
+ return nil
+}
+
+// NewType returns new message type with provided method and class.
+func NewType(method Method, class MessageClass) MessageType {
+ return MessageType{
+ Method: method,
+ Class: class,
+ }
+}
+
+const (
+ methodABits = 0xf // 0b0000000000001111
+ methodBBits = 0x70 // 0b0000000001110000
+ methodDBits = 0xf80 // 0b0000111110000000
+
+ methodBShift = 1
+ methodDShift = 2
+
+ firstBit = 0x1
+ secondBit = 0x2
+
+ c0Bit = firstBit
+ c1Bit = secondBit
+
+ classC0Shift = 4
+ classC1Shift = 7
+)
+
+// Value returns bit representation of messageType.
+func (t MessageType) Value() uint16 {
+ // 0 1
+ // 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+ // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
+ // |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
+ // |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
+ // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
+ // Figure 3: Format of STUN Message Type Field
+
+ // Warning: Abandon all hope ye who enter here.
+ // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
+ m := uint16(t.Method)
+ a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
+ b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
+ d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
+
+ // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
+ m = a + (b << methodBShift) + (d << methodDShift)
+
+ // C0 is zero bit of C, C1 is first bit.
+ // C0 = C * 0b01, C1 = (C * 0b10) >> 1
+ // Ct = C0 << 4 + C1 << 8.
+ // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
+ // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
+ // (see figure 3).
+ c := uint16(t.Class)
+ c0 := (c & c0Bit) << classC0Shift
+ c1 := (c & c1Bit) << classC1Shift
+ class := c0 + c1
+
+ return m + class
+}
+
+// ReadValue decodes uint16 into MessageType.
+func (t *MessageType) ReadValue(v uint16) {
+ // Decoding class.
+ // We are taking first bit from v >> 4 and second from v >> 7.
+ c0 := (v >> classC0Shift) & c0Bit
+ c1 := (v >> classC1Shift) & c1Bit
+ class := c0 + c1
+ t.Class = MessageClass(class)
+
+ // Decoding method.
+ a := v & methodABits // A(M0-M3)
+ b := (v >> methodBShift) & methodBBits // B(M4-M6)
+ d := (v >> methodDShift) & methodDBits // D(M7-M11)
+ m := a + b + d
+ t.Method = Method(m)
+}
+
+func (t MessageType) String() string {
+ return fmt.Sprintf("%s %s", t.Method, t.Class)
+}
+
+// Contains return true if message contain t attribute.
+func (m *Message) Contains(t AttrType) bool {
+ for _, a := range m.Attributes {
+ if a.Type == t {
+ return true
+ }
+ }
+ return false
+}
+
+type transactionIDValueSetter [TransactionIDSize]byte
+
+// NewTransactionIDSetter returns new Setter that sets message transaction id
+// to provided value.
+func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
+ return transactionIDValueSetter(value)
+}
+
+func (t transactionIDValueSetter) AddTo(m *Message) error {
+ m.TransactionID = t
+ m.WriteTransactionID()
+ return nil
+}