diff options
Diffstat (limited to 'framing')
-rw-r--r-- | framing/framing.go | 68 | ||||
-rw-r--r-- | framing/framing_test.go | 41 |
2 files changed, 52 insertions, 57 deletions
diff --git a/framing/framing.go b/framing/framing.go index 710bc1f..c57189d 100644 --- a/framing/framing.go +++ b/framing/framing.go @@ -61,6 +61,7 @@ import ( "errors" "fmt" "hash" + "io" "code.google.com/p/go.crypto/nacl/secretbox" @@ -172,40 +173,41 @@ func NewEncoder(key []byte) *Encoder { } // Encode encodes a single frame worth of payload and returns the encoded -// length and the resulting frame. InvalidPayloadLengthError is recoverable, -// all other errors MUST be treated as fatal and the session aborted. -func (encoder *Encoder) Encode(payload []byte) (int, []byte, error) { +// length. InvalidPayloadLengthError is recoverable, all other errors MUST be +// treated as fatal and the session aborted. +func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) { payloadLen := len(payload) if MaximumFramePayloadLength < payloadLen { - return 0, nil, InvalidPayloadLengthError(payloadLen) + return 0, InvalidPayloadLengthError(payloadLen) + } + if len(frame) < payloadLen + FrameOverhead { + return 0, io.ErrShortBuffer } // Generate a new nonce. var nonce [nonceLength]byte - err := encoder.nonce.bytes(&nonce) + err = encoder.nonce.bytes(&nonce) if err != nil { - return 0, nil, err + return 0, err } encoder.nonce.counter++ // Encrypt and MAC payload. - var box []byte - box = secretbox.Seal(nil, payload, &nonce, &encoder.key) + box := secretbox.Seal(frame[:lengthLength], payload, &nonce, &encoder.key) // Obfuscate the length. - length := uint16(len(box)) + length := uint16(len(box)-lengthLength) encoder.sip.Write(nonce[:]) lengthMask := encoder.sip.Sum(nil) encoder.sip.Reset() length ^= binary.BigEndian.Uint16(lengthMask) - var obfsLen [lengthLength]byte - binary.BigEndian.PutUint16(obfsLen[:], length) + binary.BigEndian.PutUint16(frame[:2], length) // Prepare the next obfsucator. - encoder.sip.Write(box) + encoder.sip.Write(box[lengthLength:]) // Return the frame. - return payloadLen + FrameOverhead, append(obfsLen[:], box...), nil + return len(box), nil } // Decoder is a frame decoder instance. @@ -233,23 +235,23 @@ func NewDecoder(key []byte) *Decoder { return decoder } -// Decode decodes a stream of data and returns the length and decoded frame if -// any. ErrAgain is a temporary failure, all other errors MUST be treated as -// fatal and the session aborted. -func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) { +// Decode decodes a stream of data and returns the length if any. ErrAgain is +// a temporary failure, all other errors MUST be treated as fatal and the +// session aborted. +func (decoder *Decoder) Decode(data []byte, frames *bytes.Buffer) (int, error) { // A length of 0 indicates that we do not know how big the next frame is // going to be. if decoder.nextLength == 0 { // Attempt to pull out the next frame length. - if lengthLength > data.Len() { - return 0, nil, ErrAgain + if lengthLength > frames.Len() { + return 0, ErrAgain } // Remove the length field from the buffer. var obfsLen [lengthLength]byte - n, err := data.Read(obfsLen[:]) + n, err := frames.Read(obfsLen[:]) if err != nil { - return 0, nil, err + return 0, err } else if n != lengthLength { // Should *NEVER* happen, since at least 2 bytes exist. panic(fmt.Sprintf("BUG: Failed to read obfuscated length: %d", n)) @@ -258,7 +260,7 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) { // Derive the nonce the peer used. err = decoder.nonce.bytes(&decoder.nextNonce) if err != nil { - return 0, nil, err + return 0, err } // Deobfuscate the length field. @@ -268,36 +270,36 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) { decoder.sip.Reset() length ^= binary.BigEndian.Uint16(lengthMask) if maxFrameLength < length || minFrameLength > length { - return 0, nil, InvalidFrameLengthError(length) + return 0, InvalidFrameLengthError(length) } decoder.nextLength = length } - if int(decoder.nextLength) > data.Len() { - return 0, nil, ErrAgain + if int(decoder.nextLength) > frames.Len() { + return 0, ErrAgain } // Unseal the frame. - box := make([]byte, decoder.nextLength) - n, err := data.Read(box) + var box [maxFrameLength]byte + n, err := frames.Read(box[:decoder.nextLength]) if err != nil { - return 0, nil, err + return 0, err } else if n != int(decoder.nextLength) { - // Should *NEVER* happen, since at least 2 bytes exist. + // Should *NEVER* happen, since the length is checked. panic(fmt.Sprintf("BUG: Failed to read secretbox, got %d, should have %d", n, decoder.nextLength)) } - out, ok := secretbox.Open(nil, box, &decoder.nextNonce, &decoder.key) + out, ok := secretbox.Open(data[:0], box[:n], &decoder.nextNonce, &decoder.key) if !ok { - return 0, nil, ErrTagMismatch + return 0, ErrTagMismatch } - decoder.sip.Write(box) + decoder.sip.Write(box[:n]) // Clean up and prepare for the next frame. decoder.nextLength = 0 decoder.nonce.counter++ - return len(out), out, nil + return len(out), nil } /* vim :set ts=4 sw=4 sts=4 noet : */ diff --git a/framing/framing_test.go b/framing/framing_test.go index 221ea5e..08f5f17 100644 --- a/framing/framing_test.go +++ b/framing/framing_test.go @@ -69,7 +69,8 @@ func TestEncoder_Encode(t *testing.T) { buf := make([]byte, MaximumFramePayloadLength) _, _ = rand.Read(buf) // YOLO for i := 0; i <= MaximumFramePayloadLength; i++ { - n, frame, err := encoder.Encode(buf[0:i]) + var frame [MaximumSegmentLength]byte + n, err := encoder.Encode(frame[:], buf[0:i]) if err != nil { t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err) } @@ -77,10 +78,6 @@ func TestEncoder_Encode(t *testing.T) { t.Fatalf("Unexpected encoded framesize: %d, expecting %d", n, i+ FrameOverhead) } - if len(frame) != n { - t.Fatalf("Encoded frame length/rval mismatch: %d != %d", - len(frame), n) - } } } @@ -88,9 +85,10 @@ func TestEncoder_Encode(t *testing.T) { func TestEncoder_Encode_Oversize(t *testing.T) { encoder := newEncoder(t) - buf := make([]byte, MaximumFramePayloadLength+1) - _, _ = rand.Read(buf) // YOLO - _, _, err := encoder.Encode(buf) + var frame [MaximumSegmentLength]byte + var buf [MaximumFramePayloadLength+1]byte + _, _ = rand.Read(buf[:]) // YOLO + _, err := encoder.Encode(frame[:], buf[:]) if _, ok := err.(InvalidPayloadLengthError); !ok { t.Error("Encoder.encode() returned unexpected error:", err) } @@ -112,10 +110,11 @@ func TestDecoder_Decode(t *testing.T) { encoder := NewEncoder(key) decoder := NewDecoder(key) - buf := make([]byte, MaximumFramePayloadLength) - _, _ = rand.Read(buf) // YOLO + var buf [MaximumFramePayloadLength]byte + _, _ = rand.Read(buf[:]) // YOLO for i := 0; i <= MaximumFramePayloadLength; i++ { - encLen, frame, err := encoder.Encode(buf[0:i]) + var frame [MaximumSegmentLength]byte + encLen, err := encoder.Encode(frame[:], buf[0:i]) if err != nil { t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err) } @@ -123,12 +122,10 @@ func TestDecoder_Decode(t *testing.T) { t.Fatalf("Unexpected encoded framesize: %d, expecting %d", encLen, i+FrameOverhead) } - if len(frame) != encLen { - t.Fatalf("Encoded frame length/rval mismatch: %d != %d", - len(frame), encLen) - } - decLen, decoded, err := decoder.Decode(bytes.NewBuffer(frame)) + var decoded [MaximumFramePayloadLength]byte + + decLen, err := decoder.Decode(decoded[:], bytes.NewBuffer(frame[:encLen])) if err != nil { t.Fatalf("Decoder.decode([%d]byte), failed: %s", i, err) } @@ -136,13 +133,8 @@ func TestDecoder_Decode(t *testing.T) { t.Fatalf("Unexpected decoded framesize: %d, expecting %d", decLen, i) } - if len(decoded) != i { - t.Fatalf("Encoded frame length/rval mismatch: %d != %d", - len(decoded), i) - - } - if 0 != bytes.Compare(decoded, buf[0:i]) { + if 0 != bytes.Compare(decoded[:decLen], buf[:i]) { t.Fatalf("Frame %d does not match encoder input", i) } } @@ -152,6 +144,7 @@ func TestDecoder_Decode(t *testing.T) { // of payload. func BenchmarkEncoder_Encode(b *testing.B) { var chopBuf [MaximumFramePayloadLength]byte + var frame [MaximumSegmentLength]byte payload := make([]byte, 1024*1024) encoder := NewEncoder(generateRandomKey()) b.ResetTimer() @@ -165,8 +158,8 @@ func BenchmarkEncoder_Encode(b *testing.B) { b.Fatal("buffer.Read() failed:", err) } - n, frame, err := encoder.Encode(chopBuf[:n]) - transfered += len(frame) - FrameOverhead + n, err = encoder.Encode(frame[:], chopBuf[:n]) + transfered += n - FrameOverhead } if transfered != len(payload) { b.Fatalf("Transfered length mismatch: %d != %d", transfered, |