summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--framing/framing.go68
-rw-r--r--framing/framing_test.go41
-rw-r--r--packet.go21
3 files changed, 63 insertions, 67 deletions
diff --git a/framing/framing.go b/framing/framing.go
index 710bc1f..c57189d 100644
--- a/framing/framing.go
+++ b/framing/framing.go
@@ -61,6 +61,7 @@ import (
"errors"
"fmt"
"hash"
+ "io"
"code.google.com/p/go.crypto/nacl/secretbox"
@@ -172,40 +173,41 @@ func NewEncoder(key []byte) *Encoder {
}
// Encode encodes a single frame worth of payload and returns the encoded
-// length and the resulting frame. InvalidPayloadLengthError is recoverable,
-// all other errors MUST be treated as fatal and the session aborted.
-func (encoder *Encoder) Encode(payload []byte) (int, []byte, error) {
+// length. InvalidPayloadLengthError is recoverable, all other errors MUST be
+// treated as fatal and the session aborted.
+func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) {
payloadLen := len(payload)
if MaximumFramePayloadLength < payloadLen {
- return 0, nil, InvalidPayloadLengthError(payloadLen)
+ return 0, InvalidPayloadLengthError(payloadLen)
+ }
+ if len(frame) < payloadLen + FrameOverhead {
+ return 0, io.ErrShortBuffer
}
// Generate a new nonce.
var nonce [nonceLength]byte
- err := encoder.nonce.bytes(&nonce)
+ err = encoder.nonce.bytes(&nonce)
if err != nil {
- return 0, nil, err
+ return 0, err
}
encoder.nonce.counter++
// Encrypt and MAC payload.
- var box []byte
- box = secretbox.Seal(nil, payload, &nonce, &encoder.key)
+ box := secretbox.Seal(frame[:lengthLength], payload, &nonce, &encoder.key)
// Obfuscate the length.
- length := uint16(len(box))
+ length := uint16(len(box)-lengthLength)
encoder.sip.Write(nonce[:])
lengthMask := encoder.sip.Sum(nil)
encoder.sip.Reset()
length ^= binary.BigEndian.Uint16(lengthMask)
- var obfsLen [lengthLength]byte
- binary.BigEndian.PutUint16(obfsLen[:], length)
+ binary.BigEndian.PutUint16(frame[:2], length)
// Prepare the next obfsucator.
- encoder.sip.Write(box)
+ encoder.sip.Write(box[lengthLength:])
// Return the frame.
- return payloadLen + FrameOverhead, append(obfsLen[:], box...), nil
+ return len(box), nil
}
// Decoder is a frame decoder instance.
@@ -233,23 +235,23 @@ func NewDecoder(key []byte) *Decoder {
return decoder
}
-// Decode decodes a stream of data and returns the length and decoded frame if
-// any. ErrAgain is a temporary failure, all other errors MUST be treated as
-// fatal and the session aborted.
-func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
+// Decode decodes a stream of data and returns the length if any. ErrAgain is
+// a temporary failure, all other errors MUST be treated as fatal and the
+// session aborted.
+func (decoder *Decoder) Decode(data []byte, frames *bytes.Buffer) (int, error) {
// A length of 0 indicates that we do not know how big the next frame is
// going to be.
if decoder.nextLength == 0 {
// Attempt to pull out the next frame length.
- if lengthLength > data.Len() {
- return 0, nil, ErrAgain
+ if lengthLength > frames.Len() {
+ return 0, ErrAgain
}
// Remove the length field from the buffer.
var obfsLen [lengthLength]byte
- n, err := data.Read(obfsLen[:])
+ n, err := frames.Read(obfsLen[:])
if err != nil {
- return 0, nil, err
+ return 0, err
} else if n != lengthLength {
// Should *NEVER* happen, since at least 2 bytes exist.
panic(fmt.Sprintf("BUG: Failed to read obfuscated length: %d", n))
@@ -258,7 +260,7 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
// Derive the nonce the peer used.
err = decoder.nonce.bytes(&decoder.nextNonce)
if err != nil {
- return 0, nil, err
+ return 0, err
}
// Deobfuscate the length field.
@@ -268,36 +270,36 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
decoder.sip.Reset()
length ^= binary.BigEndian.Uint16(lengthMask)
if maxFrameLength < length || minFrameLength > length {
- return 0, nil, InvalidFrameLengthError(length)
+ return 0, InvalidFrameLengthError(length)
}
decoder.nextLength = length
}
- if int(decoder.nextLength) > data.Len() {
- return 0, nil, ErrAgain
+ if int(decoder.nextLength) > frames.Len() {
+ return 0, ErrAgain
}
// Unseal the frame.
- box := make([]byte, decoder.nextLength)
- n, err := data.Read(box)
+ var box [maxFrameLength]byte
+ n, err := frames.Read(box[:decoder.nextLength])
if err != nil {
- return 0, nil, err
+ return 0, err
} else if n != int(decoder.nextLength) {
- // Should *NEVER* happen, since at least 2 bytes exist.
+ // Should *NEVER* happen, since the length is checked.
panic(fmt.Sprintf("BUG: Failed to read secretbox, got %d, should have %d",
n, decoder.nextLength))
}
- out, ok := secretbox.Open(nil, box, &decoder.nextNonce, &decoder.key)
+ out, ok := secretbox.Open(data[:0], box[:n], &decoder.nextNonce, &decoder.key)
if !ok {
- return 0, nil, ErrTagMismatch
+ return 0, ErrTagMismatch
}
- decoder.sip.Write(box)
+ decoder.sip.Write(box[:n])
// Clean up and prepare for the next frame.
decoder.nextLength = 0
decoder.nonce.counter++
- return len(out), out, nil
+ return len(out), nil
}
/* vim :set ts=4 sw=4 sts=4 noet : */
diff --git a/framing/framing_test.go b/framing/framing_test.go
index 221ea5e..08f5f17 100644
--- a/framing/framing_test.go
+++ b/framing/framing_test.go
@@ -69,7 +69,8 @@ func TestEncoder_Encode(t *testing.T) {
buf := make([]byte, MaximumFramePayloadLength)
_, _ = rand.Read(buf) // YOLO
for i := 0; i <= MaximumFramePayloadLength; i++ {
- n, frame, err := encoder.Encode(buf[0:i])
+ var frame [MaximumSegmentLength]byte
+ n, err := encoder.Encode(frame[:], buf[0:i])
if err != nil {
t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err)
}
@@ -77,10 +78,6 @@ func TestEncoder_Encode(t *testing.T) {
t.Fatalf("Unexpected encoded framesize: %d, expecting %d", n, i+
FrameOverhead)
}
- if len(frame) != n {
- t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
- len(frame), n)
- }
}
}
@@ -88,9 +85,10 @@ func TestEncoder_Encode(t *testing.T) {
func TestEncoder_Encode_Oversize(t *testing.T) {
encoder := newEncoder(t)
- buf := make([]byte, MaximumFramePayloadLength+1)
- _, _ = rand.Read(buf) // YOLO
- _, _, err := encoder.Encode(buf)
+ var frame [MaximumSegmentLength]byte
+ var buf [MaximumFramePayloadLength+1]byte
+ _, _ = rand.Read(buf[:]) // YOLO
+ _, err := encoder.Encode(frame[:], buf[:])
if _, ok := err.(InvalidPayloadLengthError); !ok {
t.Error("Encoder.encode() returned unexpected error:", err)
}
@@ -112,10 +110,11 @@ func TestDecoder_Decode(t *testing.T) {
encoder := NewEncoder(key)
decoder := NewDecoder(key)
- buf := make([]byte, MaximumFramePayloadLength)
- _, _ = rand.Read(buf) // YOLO
+ var buf [MaximumFramePayloadLength]byte
+ _, _ = rand.Read(buf[:]) // YOLO
for i := 0; i <= MaximumFramePayloadLength; i++ {
- encLen, frame, err := encoder.Encode(buf[0:i])
+ var frame [MaximumSegmentLength]byte
+ encLen, err := encoder.Encode(frame[:], buf[0:i])
if err != nil {
t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err)
}
@@ -123,12 +122,10 @@ func TestDecoder_Decode(t *testing.T) {
t.Fatalf("Unexpected encoded framesize: %d, expecting %d", encLen,
i+FrameOverhead)
}
- if len(frame) != encLen {
- t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
- len(frame), encLen)
- }
- decLen, decoded, err := decoder.Decode(bytes.NewBuffer(frame))
+ var decoded [MaximumFramePayloadLength]byte
+
+ decLen, err := decoder.Decode(decoded[:], bytes.NewBuffer(frame[:encLen]))
if err != nil {
t.Fatalf("Decoder.decode([%d]byte), failed: %s", i, err)
}
@@ -136,13 +133,8 @@ func TestDecoder_Decode(t *testing.T) {
t.Fatalf("Unexpected decoded framesize: %d, expecting %d",
decLen, i)
}
- if len(decoded) != i {
- t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
- len(decoded), i)
-
- }
- if 0 != bytes.Compare(decoded, buf[0:i]) {
+ if 0 != bytes.Compare(decoded[:decLen], buf[:i]) {
t.Fatalf("Frame %d does not match encoder input", i)
}
}
@@ -152,6 +144,7 @@ func TestDecoder_Decode(t *testing.T) {
// of payload.
func BenchmarkEncoder_Encode(b *testing.B) {
var chopBuf [MaximumFramePayloadLength]byte
+ var frame [MaximumSegmentLength]byte
payload := make([]byte, 1024*1024)
encoder := NewEncoder(generateRandomKey())
b.ResetTimer()
@@ -165,8 +158,8 @@ func BenchmarkEncoder_Encode(b *testing.B) {
b.Fatal("buffer.Read() failed:", err)
}
- n, frame, err := encoder.Encode(chopBuf[:n])
- transfered += len(frame) - FrameOverhead
+ n, err = encoder.Encode(frame[:], chopBuf[:n])
+ transfered += n - FrameOverhead
}
if transfered != len(payload) {
b.Fatalf("Transfered length mismatch: %d != %d", transfered,
diff --git a/packet.go b/packet.go
index 339a86d..6f9eb03 100644
--- a/packet.go
+++ b/packet.go
@@ -100,18 +100,18 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe
pktLen := packetOverhead + len(data) + int(padLen)
// Encode the packet in an AEAD frame.
- // TODO: Change Encode to write into frame directly
- var frame []byte
- _, frame, err = c.encoder.Encode(pkt[:pktLen])
+ var frame [framing.MaximumSegmentLength]byte
+ frameLen := 0
+ frameLen, err = c.encoder.Encode(frame[:], pkt[:pktLen])
if err != nil {
// All encoder errors are fatal.
return
}
var wrLen int
- wrLen, err = w.Write(frame)
+ wrLen, err = w.Write(frame[:frameLen])
if err != nil {
return
- } else if wrLen < len(frame) {
+ } else if wrLen < frameLen {
err = io.ErrShortWrite
return
}
@@ -132,22 +132,23 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
}
c.receiveBuffer.Write(buf[:rdLen])
+ var decoded [framing.MaximumFramePayloadLength]byte
for c.receiveBuffer.Len() > 0 {
// Decrypt an AEAD frame.
- // TODO: Change Decode to write into packet directly
- var pkt []byte
- _, pkt, err = c.decoder.Decode(&c.receiveBuffer)
+ decLen := 0
+ decLen, err = c.decoder.Decode(decoded[:], &c.receiveBuffer)
if err == framing.ErrAgain {
// The accumulated payload does not make up a full frame.
return
} else if err != nil {
break
- } else if len(pkt) < packetOverhead {
- err = InvalidPacketLengthError(len(pkt))
+ } else if decLen < packetOverhead {
+ err = InvalidPacketLengthError(decLen)
break
}
// Decode the packet.
+ pkt := decoded[0:decLen]
pktType := pkt[0]
payloadLen := binary.BigEndian.Uint16(pkt[1:])
if int(payloadLen) > len(pkt)-packetOverhead {