From 582aa3a366cb06bf793440f62fcf3d54b85d8270 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Wed, 14 May 2014 05:48:43 +0000 Subject: First pass at cleaning up the read code. --- packet.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 18 deletions(-) (limited to 'packet.go') diff --git a/packet.go b/packet.go index 756ee7a..6726ed4 100644 --- a/packet.go +++ b/packet.go @@ -30,6 +30,7 @@ package obfs4 import ( "encoding/binary" "fmt" + "syscall" "github.com/yawning/obfs4/framing" ) @@ -38,6 +39,8 @@ const ( packetOverhead = 2 + 1 maxPacketPayloadLength = framing.MaximumFramePayloadLength - packetOverhead maxPacketPaddingLength = maxPacketPayloadLength + + consumeReadSize = framing.MaximumSegmentLength * 16 ) const ( @@ -98,32 +101,61 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint return n, frame, err } -func (c *Obfs4Conn) decodePacket(pkt []byte) error { - if len(pkt) < packetOverhead { - return InvalidPacketLengthError(len(pkt)) +func (c *Obfs4Conn) consumeFramedPackets() (err error) { + if !c.isOk { + return syscall.EINVAL } - pktType := pkt[0] - payloadLen := binary.BigEndian.Uint16(pkt[1:]) - if int(payloadLen) > len(pkt)-packetOverhead { - return InvalidPayloadLengthError(int(payloadLen)) + var buf [consumeReadSize]byte + n, err := c.conn.Read(buf[:]) + if err != nil { + return } + c.receiveBuffer.Write(buf[:n]) + + for c.receiveBuffer.Len() > 0 { + // Decrypt an AEAD frame. + _, pkt, err := c.decoder.Decode(&c.receiveBuffer) + if err == framing.ErrAgain { + // The accumulated payload does not make up a full frame. + return nil + } else if err != nil { + break + } else if len(pkt) < packetOverhead { + err = InvalidPacketLengthError(len(pkt)) + break + } - payload := pkt[3 : 3+payloadLen] - switch pktType { - case packetTypePayload: - if len(payload) > 0 { - c.receiveDecodedBuffer.Write(payload) + // Decode the packet. + pktType := pkt[0] + payloadLen := binary.BigEndian.Uint16(pkt[1:]) + if int(payloadLen) > len(pkt)-packetOverhead { + err = InvalidPayloadLengthError(int(payloadLen)) + break } - case packetTypePrngSeed: - if len(payload) == distSeedLength { - c.probDist.reset(payload) + payload := pkt[3 : 3+payloadLen] + + switch pktType { + case packetTypePayload: + if payloadLen > 0 { + c.receiveDecodedBuffer.Write(payload) + } + case packetTypePrngSeed: + // Only regenerate the distribution if we are the client. + if len(payload) >= distSeedLength && !c.isServer { + c.probDist.reset(payload[:distSeedLength]) + } + default: + // Ignore unrecognised packet types. } - default: - // Ignore unrecognised packet types. } - return nil + // All errors that reach this point are fatal. + if err != nil { + c.isOk = false + } + + return } /* vim :set ts=4 sw=4 sts=4 noet : */ -- cgit v1.2.3