summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYawning Angel <yawning@schwanenlied.me>2014-05-14 05:48:43 +0000
committerYawning Angel <yawning@schwanenlied.me>2014-05-14 05:48:43 +0000
commit582aa3a366cb06bf793440f62fcf3d54b85d8270 (patch)
treea433c0af292c4b345c33612b8bcbc583c3a26f93
parentece57277dbfa4aae86745f337229b779a9c298c2 (diff)
First pass at cleaning up the read code.
-rw-r--r--obfs4.go39
-rw-r--r--packet.go68
2 files changed, 55 insertions, 52 deletions
diff --git a/obfs4.go b/obfs4.go
index 3cb464e..f2a0139 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -252,49 +252,20 @@ func (c *Obfs4Conn) ServerHandshake() error {
return err
}
-func (c *Obfs4Conn) Read(b []byte) (int, error) {
+func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
if !c.isOk {
return 0, syscall.EINVAL
}
- if c.receiveDecodedBuffer.Len() > 0 {
- n, err := c.receiveDecodedBuffer.Read(b)
- return n, err
- }
-
- // Consume and decode frames off the network.
- buf := make([]byte, defaultReadSize)
for c.receiveDecodedBuffer.Len() == 0 {
- n, err := c.conn.Read(buf)
+ err = c.consumeFramedPackets()
if err != nil {
- return 0, err
- }
- c.receiveBuffer.Write(buf[:n])
-
- // Decode the data just read.
- for c.receiveBuffer.Len() > 0 {
- _, frame, err := c.decoder.Decode(&c.receiveBuffer)
- if err == framing.ErrAgain {
- break
- } else if err != nil {
- // Any other frame decoder errors are fatal.
- c.isOk = false
- return 0, err
- }
-
- // Decode the packet, if there is payload, it will be placed in
- // receiveDecodedBuffer automatically.
- err = c.decodePacket(frame)
- if err != nil {
- // All packet decoder errors are fatal.
- c.isOk = false
- return 0, err
- }
+ return
}
}
- n, err := c.receiveDecodedBuffer.Read(b)
- return n, err
+ n, err = c.receiveDecodedBuffer.Read(b)
+ return
}
func (c *Obfs4Conn) Write(b []byte) (int, error) {
diff --git a/packet.go b/packet.go
index 756ee7a..6726ed4 100644
--- a/packet.go
+++ b/packet.go
@@ -30,6 +30,7 @@ package obfs4
import (
"encoding/binary"
"fmt"
+ "syscall"
"github.com/yawning/obfs4/framing"
)
@@ -38,6 +39,8 @@ const (
packetOverhead = 2 + 1
maxPacketPayloadLength = framing.MaximumFramePayloadLength - packetOverhead
maxPacketPaddingLength = maxPacketPayloadLength
+
+ consumeReadSize = framing.MaximumSegmentLength * 16
)
const (
@@ -98,32 +101,61 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint
return n, frame, err
}
-func (c *Obfs4Conn) decodePacket(pkt []byte) error {
- if len(pkt) < packetOverhead {
- return InvalidPacketLengthError(len(pkt))
+func (c *Obfs4Conn) consumeFramedPackets() (err error) {
+ if !c.isOk {
+ return syscall.EINVAL
}
- pktType := pkt[0]
- payloadLen := binary.BigEndian.Uint16(pkt[1:])
- if int(payloadLen) > len(pkt)-packetOverhead {
- return InvalidPayloadLengthError(int(payloadLen))
+ var buf [consumeReadSize]byte
+ n, err := c.conn.Read(buf[:])
+ if err != nil {
+ return
}
+ c.receiveBuffer.Write(buf[:n])
+
+ for c.receiveBuffer.Len() > 0 {
+ // Decrypt an AEAD frame.
+ _, pkt, err := c.decoder.Decode(&c.receiveBuffer)
+ if err == framing.ErrAgain {
+ // The accumulated payload does not make up a full frame.
+ return nil
+ } else if err != nil {
+ break
+ } else if len(pkt) < packetOverhead {
+ err = InvalidPacketLengthError(len(pkt))
+ break
+ }
- payload := pkt[3 : 3+payloadLen]
- switch pktType {
- case packetTypePayload:
- if len(payload) > 0 {
- c.receiveDecodedBuffer.Write(payload)
+ // Decode the packet.
+ pktType := pkt[0]
+ payloadLen := binary.BigEndian.Uint16(pkt[1:])
+ if int(payloadLen) > len(pkt)-packetOverhead {
+ err = InvalidPayloadLengthError(int(payloadLen))
+ break
}
- case packetTypePrngSeed:
- if len(payload) == distSeedLength {
- c.probDist.reset(payload)
+ payload := pkt[3 : 3+payloadLen]
+
+ switch pktType {
+ case packetTypePayload:
+ if payloadLen > 0 {
+ c.receiveDecodedBuffer.Write(payload)
+ }
+ case packetTypePrngSeed:
+ // Only regenerate the distribution if we are the client.
+ if len(payload) >= distSeedLength && !c.isServer {
+ c.probDist.reset(payload[:distSeedLength])
+ }
+ default:
+ // Ignore unrecognised packet types.
}
- default:
- // Ignore unrecognised packet types.
}
- return nil
+ // All errors that reach this point are fatal.
+ if err != nil {
+ c.isOk = false
+ }
+
+ return
}
/* vim :set ts=4 sw=4 sts=4 noet : */