summaryrefslogtreecommitdiff
path: root/packet.go
diff options
context:
space:
mode:
authorYawning Angel <yawning@schwanenlied.me>2014-05-14 05:48:43 +0000
committerYawning Angel <yawning@schwanenlied.me>2014-05-14 05:48:43 +0000
commit582aa3a366cb06bf793440f62fcf3d54b85d8270 (patch)
treea433c0af292c4b345c33612b8bcbc583c3a26f93 /packet.go
parentece57277dbfa4aae86745f337229b779a9c298c2 (diff)
First pass at cleaning up the read code.
Diffstat (limited to 'packet.go')
-rw-r--r--packet.go68
1 files changed, 50 insertions, 18 deletions
diff --git a/packet.go b/packet.go
index 756ee7a..6726ed4 100644
--- a/packet.go
+++ b/packet.go
@@ -30,6 +30,7 @@ package obfs4
import (
"encoding/binary"
"fmt"
+ "syscall"
"github.com/yawning/obfs4/framing"
)
@@ -38,6 +39,8 @@ const (
packetOverhead = 2 + 1
maxPacketPayloadLength = framing.MaximumFramePayloadLength - packetOverhead
maxPacketPaddingLength = maxPacketPayloadLength
+
+ consumeReadSize = framing.MaximumSegmentLength * 16
)
const (
@@ -98,32 +101,61 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint
return n, frame, err
}
-func (c *Obfs4Conn) decodePacket(pkt []byte) error {
- if len(pkt) < packetOverhead {
- return InvalidPacketLengthError(len(pkt))
+func (c *Obfs4Conn) consumeFramedPackets() (err error) {
+ if !c.isOk {
+ return syscall.EINVAL
}
- pktType := pkt[0]
- payloadLen := binary.BigEndian.Uint16(pkt[1:])
- if int(payloadLen) > len(pkt)-packetOverhead {
- return InvalidPayloadLengthError(int(payloadLen))
+ var buf [consumeReadSize]byte
+ n, err := c.conn.Read(buf[:])
+ if err != nil {
+ return
}
+ c.receiveBuffer.Write(buf[:n])
+
+ for c.receiveBuffer.Len() > 0 {
+ // Decrypt an AEAD frame.
+ _, pkt, err := c.decoder.Decode(&c.receiveBuffer)
+ if err == framing.ErrAgain {
+ // The accumulated payload does not make up a full frame.
+ return nil
+ } else if err != nil {
+ break
+ } else if len(pkt) < packetOverhead {
+ err = InvalidPacketLengthError(len(pkt))
+ break
+ }
- payload := pkt[3 : 3+payloadLen]
- switch pktType {
- case packetTypePayload:
- if len(payload) > 0 {
- c.receiveDecodedBuffer.Write(payload)
+ // Decode the packet.
+ pktType := pkt[0]
+ payloadLen := binary.BigEndian.Uint16(pkt[1:])
+ if int(payloadLen) > len(pkt)-packetOverhead {
+ err = InvalidPayloadLengthError(int(payloadLen))
+ break
}
- case packetTypePrngSeed:
- if len(payload) == distSeedLength {
- c.probDist.reset(payload)
+ payload := pkt[3 : 3+payloadLen]
+
+ switch pktType {
+ case packetTypePayload:
+ if payloadLen > 0 {
+ c.receiveDecodedBuffer.Write(payload)
+ }
+ case packetTypePrngSeed:
+ // Only regenerate the distribution if we are the client.
+ if len(payload) >= distSeedLength && !c.isServer {
+ c.probDist.reset(payload[:distSeedLength])
+ }
+ default:
+ // Ignore unrecognised packet types.
}
- default:
- // Ignore unrecognised packet types.
}
- return nil
+ // All errors that reach this point are fatal.
+ if err != nil {
+ c.isOk = false
+ }
+
+ return
}
/* vim :set ts=4 sw=4 sts=4 noet : */