diff options
Diffstat (limited to 'vendor/github.com/pion/stun/message.go')
-rw-r--r-- | vendor/github.com/pion/stun/message.go | 588 |
1 files changed, 588 insertions, 0 deletions
diff --git a/vendor/github.com/pion/stun/message.go b/vendor/github.com/pion/stun/message.go new file mode 100644 index 0000000..3819235 --- /dev/null +++ b/vendor/github.com/pion/stun/message.go @@ -0,0 +1,588 @@ +package stun + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" +) + +const ( + // magicCookie is fixed value that aids in distinguishing STUN packets + // from packets of other protocols when STUN is multiplexed with those + // other protocols on the same Port. + // + // The magic cookie field MUST contain the fixed value 0x2112A442 in + // network byte order. + // + // Defined in "STUN Message Structure", section 6. + magicCookie = 0x2112A442 + attributeHeaderSize = 4 + messageHeaderSize = 20 + + // TransactionIDSize is length of transaction id array (in bytes). + TransactionIDSize = 12 // 96 bit +) + +// NewTransactionID returns new random transaction ID using crypto/rand +// as source. +func NewTransactionID() (b [TransactionIDSize]byte) { + readFullOrPanic(rand.Reader, b[:]) + return b +} + +// IsMessage returns true if b looks like STUN message. +// Useful for multiplexing. IsMessage does not guarantee +// that decoding will be successful. +func IsMessage(b []byte) bool { + return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie +} + +// New returns *Message with pre-allocated Raw. +func New() *Message { + const defaultRawCapacity = 120 + return &Message{ + Raw: make([]byte, messageHeaderSize, defaultRawCapacity), + } +} + +// ErrDecodeToNil occurs on Decode(data, nil) call. +var ErrDecodeToNil = errors.New("attempt to decode to nil message") + +// Decode decodes Message from data to m, returning error if any. +func Decode(data []byte, m *Message) error { + if m == nil { + return ErrDecodeToNil + } + m.Raw = append(m.Raw[:0], data...) + return m.Decode() +} + +// Message represents a single STUN packet. It uses aggressive internal +// buffering to enable zero-allocation encoding and decoding, +// so there are some usage constraints: +// +// Message, its fields, results of m.Get or any attribute a.GetFrom +// are valid only until Message.Raw is not modified. +type Message struct { + Type MessageType + Length uint32 // len(Raw) not including header + TransactionID [TransactionIDSize]byte + Attributes Attributes + Raw []byte +} + +// AddTo sets b.TransactionID to m.TransactionID. +// +// Implements Setter to aid in crafting responses. +func (m *Message) AddTo(b *Message) error { + b.TransactionID = m.TransactionID + b.WriteTransactionID() + return nil +} + +// NewTransactionID sets m.TransactionID to random value from crypto/rand +// and returns error if any. +func (m *Message) NewTransactionID() error { + _, err := io.ReadFull(rand.Reader, m.TransactionID[:]) + if err == nil { + m.WriteTransactionID() + } + return err +} + +func (m *Message) String() string { + tID := base64.StdEncoding.EncodeToString(m.TransactionID[:]) + return fmt.Sprintf("%s l=%d attrs=%d id=%s", m.Type, m.Length, len(m.Attributes), tID) +} + +// Reset resets Message, attributes and underlying buffer length. +func (m *Message) Reset() { + m.Raw = m.Raw[:0] + m.Length = 0 + m.Attributes = m.Attributes[:0] +} + +// grow ensures that internal buffer has n length. +func (m *Message) grow(n int) { + if len(m.Raw) >= n { + return + } + if cap(m.Raw) >= n { + m.Raw = m.Raw[:n] + return + } + m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...) +} + +// Add appends new attribute to message. Not goroutine-safe. +// +// Value of attribute is copied to internal buffer so +// it is safe to reuse v. +func (m *Message) Add(t AttrType, v []byte) { + // Allocating buffer for TLV (type-length-value). + // T = t, L = len(v), V = v. + // m.Raw will look like: + // [0:20] <- message header + // [20:20+m.Length] <- existing message attributes + // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV + // [first:last] <- same as previous + // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer + // T L V + allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V) + first := messageHeaderSize + int(m.Length) // first byte number + last := first + allocSize // last byte number + m.grow(last) // growing cap(Raw) to fit TLV + m.Raw = m.Raw[:last] // now len(Raw) = last + m.Length += uint32(allocSize) // rendering length change + + // Sub-slicing internal buffer to simplify encoding. + buf := m.Raw[first:last] // slice for TLV + value := buf[attributeHeaderSize:] // slice for V + attr := RawAttribute{ + Type: t, // T + Length: uint16(len(v)), // L + Value: value, // V + } + + // Encoding attribute TLV to allocated buffer. + bin.PutUint16(buf[0:2], attr.Type.Value()) // T + bin.PutUint16(buf[2:4], attr.Length) // L + copy(value, v) // V + + // Checking that attribute value needs padding. + if attr.Length%padding != 0 { + // Performing padding. + bytesToAdd := nearestPaddedValueLength(len(v)) - len(v) + last += bytesToAdd + m.grow(last) + // setting all padding bytes to zero + // to prevent data leak from previous + // data in next bytesToAdd bytes + buf = m.Raw[last-bytesToAdd : last] + for i := range buf { + buf[i] = 0 + } + m.Raw = m.Raw[:last] // increasing buffer length + m.Length += uint32(bytesToAdd) // rendering length change + } + m.Attributes = append(m.Attributes, attr) + m.WriteLength() +} + +func attrSliceEqual(a, b Attributes) bool { + for _, attr := range a { + found := false + for _, attrB := range b { + if attrB.Type != attr.Type { + continue + } + if attrB.Equal(attr) { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +func attrEqual(a, b Attributes) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + if len(a) != len(b) { + return false + } + if !attrSliceEqual(a, b) { + return false + } + if !attrSliceEqual(b, a) { + return false + } + return true +} + +// Equal returns true if Message b equals to m. +// Ignores m.Raw. +func (m *Message) Equal(b *Message) bool { + if m == nil && b == nil { + return true + } + if m == nil || b == nil { + return false + } + if m.Type != b.Type { + return false + } + if m.TransactionID != b.TransactionID { + return false + } + if m.Length != b.Length { + return false + } + if !attrEqual(m.Attributes, b.Attributes) { + return false + } + return true +} + +// WriteLength writes m.Length to m.Raw. +func (m *Message) WriteLength() { + m.grow(4) + bin.PutUint16(m.Raw[2:4], uint16(m.Length)) +} + +// WriteHeader writes header to underlying buffer. Not goroutine-safe. +func (m *Message) WriteHeader() { + m.grow(messageHeaderSize) + _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below + + m.WriteType() + m.WriteLength() + bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie + copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID +} + +// WriteTransactionID writes m.TransactionID to m.Raw. +func (m *Message) WriteTransactionID() { + copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID +} + +// WriteAttributes encodes all m.Attributes to m. +func (m *Message) WriteAttributes() { + attributes := m.Attributes + m.Attributes = attributes[:0] + for _, a := range attributes { + m.Add(a.Type, a.Value) + } + m.Attributes = attributes +} + +// WriteType writes m.Type to m.Raw. +func (m *Message) WriteType() { + m.grow(2) + bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type +} + +// SetType sets m.Type and writes it to m.Raw. +func (m *Message) SetType(t MessageType) { + m.Type = t + m.WriteType() +} + +// Encode re-encodes message into m.Raw. +func (m *Message) Encode() { + m.Raw = m.Raw[:0] + m.WriteHeader() + m.Length = 0 + m.WriteAttributes() +} + +// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning +// call result. +func (m *Message) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.Raw) + return int64(n), err +} + +// ReadFrom implements ReaderFrom. Reads message from r into m.Raw, +// Decodes it and return error if any. If m.Raw is too small, will return +// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr. +// +// Can return *DecodeErr while decoding too. +func (m *Message) ReadFrom(r io.Reader) (int64, error) { + tBuf := m.Raw[:cap(m.Raw)] + var ( + n int + err error + ) + if n, err = r.Read(tBuf); err != nil { + return int64(n), err + } + m.Raw = tBuf[:n] + return int64(n), m.Decode() +} + +// ErrUnexpectedHeaderEOF means that there were not enough bytes in +// m.Raw to read header. +var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header") + +// Decode decodes m.Raw into m. +func (m *Message) Decode() error { + // decoding message header + buf := m.Raw + if len(buf) < messageHeaderSize { + return ErrUnexpectedHeaderEOF + } + var ( + t = bin.Uint16(buf[0:2]) // first 2 bytes + size = int(bin.Uint16(buf[2:4])) // second 2 bytes + cookie = bin.Uint32(buf[4:8]) // last 4 bytes + fullSize = messageHeaderSize + size // len(m.Raw) + ) + if cookie != magicCookie { + msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie) + return newDecodeErr("message", "cookie", msg) + } + if len(buf) < fullSize { + msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize) + return newAttrDecodeErr("message", msg) + } + // saving header data + m.Type.ReadValue(t) + m.Length = uint32(size) + copy(m.TransactionID[:], buf[8:messageHeaderSize]) + + m.Attributes = m.Attributes[:0] + var ( + offset = 0 + b = buf[messageHeaderSize:fullSize] + ) + for offset < size { + // checking that we have enough bytes to read header + if len(b) < attributeHeaderSize { + msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize) + return newAttrDecodeErr("header", msg) + } + var ( + a = RawAttribute{ + Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes + Length: bin.Uint16(b[2:4]), // second 2 bytes + } + aL = int(a.Length) // attribute length + aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding) + ) + b = b[attributeHeaderSize:] // slicing again to simplify value read + offset += attributeHeaderSize + if len(b) < aBuffL { // checking size + msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type) + return newAttrDecodeErr("value", msg) + } + a.Value = b[:aL] + offset += aBuffL + b = b[aBuffL:] + + m.Attributes = append(m.Attributes, a) + } + return nil +} + +// Write decodes message and return error if any. +// +// Any error is unrecoverable, but message could be partially decoded. +func (m *Message) Write(tBuf []byte) (int, error) { + m.Raw = append(m.Raw[:0], tBuf...) + return len(tBuf), m.Decode() +} + +// CloneTo clones m to b securing any further m mutations. +func (m *Message) CloneTo(b *Message) error { + // TODO(ar): implement low-level copy. + b.Raw = append(b.Raw[:0], m.Raw...) + return b.Decode() +} + +// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. +type MessageClass byte + +// Possible values for message class in STUN Message Type. +const ( + ClassRequest MessageClass = 0x00 // 0b00 + ClassIndication MessageClass = 0x01 // 0b01 + ClassSuccessResponse MessageClass = 0x02 // 0b10 + ClassErrorResponse MessageClass = 0x03 // 0b11 +) + +// Common STUN message types. +var ( + // Binding request message type. + BindingRequest = NewType(MethodBinding, ClassRequest) + // Binding success response message type + BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) + // Binding error response message type. + BindingError = NewType(MethodBinding, ClassErrorResponse) +) + +func (c MessageClass) String() string { + switch c { + case ClassRequest: + return "request" + case ClassIndication: + return "indication" + case ClassSuccessResponse: + return "success response" + case ClassErrorResponse: + return "error response" + default: + panic("unknown message class") // nolint: never happens unless wrongly casted + } +} + +// Method is uint16 representation of 12-bit STUN method. +type Method uint16 + +// Possible methods for STUN Message. +const ( + MethodBinding Method = 0x001 + MethodAllocate Method = 0x003 + MethodRefresh Method = 0x004 + MethodSend Method = 0x006 + MethodData Method = 0x007 + MethodCreatePermission Method = 0x008 + MethodChannelBind Method = 0x009 +) + +// Methods from RFC 6062. +const ( + MethodConnect Method = 0x000a + MethodConnectionBind Method = 0x000b + MethodConnectionAttempt Method = 0x000c +) + +var methodName = map[Method]string{ + MethodBinding: "Binding", + MethodAllocate: "Allocate", + MethodRefresh: "Refresh", + MethodSend: "Send", + MethodData: "Data", + MethodCreatePermission: "CreatePermission", + MethodChannelBind: "ChannelBind", + + // RFC 6062. + MethodConnect: "Connect", + MethodConnectionBind: "ConnectionBind", + MethodConnectionAttempt: "ConnectionAttempt", +} + +func (m Method) String() string { + s, ok := methodName[m] + if !ok { + // Falling back to hex representation. + s = fmt.Sprintf("0x%x", uint16(m)) + } + return s +} + +// MessageType is STUN Message Type Field. +type MessageType struct { + Method Method // e.g. binding + Class MessageClass // e.g. request +} + +// AddTo sets m type to t. +func (t MessageType) AddTo(m *Message) error { + m.SetType(t) + return nil +} + +// NewType returns new message type with provided method and class. +func NewType(method Method, class MessageClass) MessageType { + return MessageType{ + Method: method, + Class: class, + } +} + +const ( + methodABits = 0xf // 0b0000000000001111 + methodBBits = 0x70 // 0b0000000001110000 + methodDBits = 0xf80 // 0b0000111110000000 + + methodBShift = 1 + methodDShift = 2 + + firstBit = 0x1 + secondBit = 0x2 + + c0Bit = firstBit + c1Bit = secondBit + + classC0Shift = 4 + classC1Shift = 7 +) + +// Value returns bit representation of messageType. +func (t MessageType) Value() uint16 { + // 0 1 + // 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + // |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + // |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + // Figure 3: Format of STUN Message Type Field + + // Warning: Abandon all hope ye who enter here. + // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). + m := uint16(t.Method) + a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits) + b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A) + d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B) + + // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). + m = a + (b << methodBShift) + (d << methodDShift) + + // C0 is zero bit of C, C1 is first bit. + // C0 = C * 0b01, C1 = (C * 0b10) >> 1 + // Ct = C0 << 4 + C1 << 8. + // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" + // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions + // (see figure 3). + c := uint16(t.Class) + c0 := (c & c0Bit) << classC0Shift + c1 := (c & c1Bit) << classC1Shift + class := c0 + c1 + + return m + class +} + +// ReadValue decodes uint16 into MessageType. +func (t *MessageType) ReadValue(v uint16) { + // Decoding class. + // We are taking first bit from v >> 4 and second from v >> 7. + c0 := (v >> classC0Shift) & c0Bit + c1 := (v >> classC1Shift) & c1Bit + class := c0 + c1 + t.Class = MessageClass(class) + + // Decoding method. + a := v & methodABits // A(M0-M3) + b := (v >> methodBShift) & methodBBits // B(M4-M6) + d := (v >> methodDShift) & methodDBits // D(M7-M11) + m := a + b + d + t.Method = Method(m) +} + +func (t MessageType) String() string { + return fmt.Sprintf("%s %s", t.Method, t.Class) +} + +// Contains return true if message contain t attribute. +func (m *Message) Contains(t AttrType) bool { + for _, a := range m.Attributes { + if a.Type == t { + return true + } + } + return false +} + +type transactionIDValueSetter [TransactionIDSize]byte + +// NewTransactionIDSetter returns new Setter that sets message transaction id +// to provided value. +func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter { + return transactionIDValueSetter(value) +} + +func (t transactionIDValueSetter) AddTo(m *Message) error { + m.TransactionID = t + m.WriteTransactionID() + return nil +} |