diff options
author | Yawning Angel <yawning@schwanenlied.me> | 2014-05-14 05:48:43 +0000 |
---|---|---|
committer | Yawning Angel <yawning@schwanenlied.me> | 2014-05-14 05:48:43 +0000 |
commit | 582aa3a366cb06bf793440f62fcf3d54b85d8270 (patch) | |
tree | a433c0af292c4b345c33612b8bcbc583c3a26f93 | |
parent | ece57277dbfa4aae86745f337229b779a9c298c2 (diff) |
First pass at cleaning up the read code.
-rw-r--r-- | obfs4.go | 39 | ||||
-rw-r--r-- | packet.go | 68 |
2 files changed, 55 insertions, 52 deletions
@@ -252,49 +252,20 @@ func (c *Obfs4Conn) ServerHandshake() error { return err } -func (c *Obfs4Conn) Read(b []byte) (int, error) { +func (c *Obfs4Conn) Read(b []byte) (n int, err error) { if !c.isOk { return 0, syscall.EINVAL } - if c.receiveDecodedBuffer.Len() > 0 { - n, err := c.receiveDecodedBuffer.Read(b) - return n, err - } - - // Consume and decode frames off the network. - buf := make([]byte, defaultReadSize) for c.receiveDecodedBuffer.Len() == 0 { - n, err := c.conn.Read(buf) + err = c.consumeFramedPackets() if err != nil { - return 0, err - } - c.receiveBuffer.Write(buf[:n]) - - // Decode the data just read. - for c.receiveBuffer.Len() > 0 { - _, frame, err := c.decoder.Decode(&c.receiveBuffer) - if err == framing.ErrAgain { - break - } else if err != nil { - // Any other frame decoder errors are fatal. - c.isOk = false - return 0, err - } - - // Decode the packet, if there is payload, it will be placed in - // receiveDecodedBuffer automatically. - err = c.decodePacket(frame) - if err != nil { - // All packet decoder errors are fatal. - c.isOk = false - return 0, err - } + return } } - n, err := c.receiveDecodedBuffer.Read(b) - return n, err + n, err = c.receiveDecodedBuffer.Read(b) + return } func (c *Obfs4Conn) Write(b []byte) (int, error) { @@ -30,6 +30,7 @@ package obfs4 import ( "encoding/binary" "fmt" + "syscall" "github.com/yawning/obfs4/framing" ) @@ -38,6 +39,8 @@ const ( packetOverhead = 2 + 1 maxPacketPayloadLength = framing.MaximumFramePayloadLength - packetOverhead maxPacketPaddingLength = maxPacketPayloadLength + + consumeReadSize = framing.MaximumSegmentLength * 16 ) const ( @@ -98,32 +101,61 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint return n, frame, err } -func (c *Obfs4Conn) decodePacket(pkt []byte) error { - if len(pkt) < packetOverhead { - return InvalidPacketLengthError(len(pkt)) +func (c *Obfs4Conn) consumeFramedPackets() (err error) { + if !c.isOk { + return syscall.EINVAL } - pktType := pkt[0] - payloadLen := binary.BigEndian.Uint16(pkt[1:]) - if int(payloadLen) > len(pkt)-packetOverhead { - return InvalidPayloadLengthError(int(payloadLen)) + var buf [consumeReadSize]byte + n, err := c.conn.Read(buf[:]) + if err != nil { + return } + c.receiveBuffer.Write(buf[:n]) + + for c.receiveBuffer.Len() > 0 { + // Decrypt an AEAD frame. + _, pkt, err := c.decoder.Decode(&c.receiveBuffer) + if err == framing.ErrAgain { + // The accumulated payload does not make up a full frame. + return nil + } else if err != nil { + break + } else if len(pkt) < packetOverhead { + err = InvalidPacketLengthError(len(pkt)) + break + } - payload := pkt[3 : 3+payloadLen] - switch pktType { - case packetTypePayload: - if len(payload) > 0 { - c.receiveDecodedBuffer.Write(payload) + // Decode the packet. + pktType := pkt[0] + payloadLen := binary.BigEndian.Uint16(pkt[1:]) + if int(payloadLen) > len(pkt)-packetOverhead { + err = InvalidPayloadLengthError(int(payloadLen)) + break } - case packetTypePrngSeed: - if len(payload) == distSeedLength { - c.probDist.reset(payload) + payload := pkt[3 : 3+payloadLen] + + switch pktType { + case packetTypePayload: + if payloadLen > 0 { + c.receiveDecodedBuffer.Write(payload) + } + case packetTypePrngSeed: + // Only regenerate the distribution if we are the client. + if len(payload) >= distSeedLength && !c.isServer { + c.probDist.reset(payload[:distSeedLength]) + } + default: + // Ignore unrecognised packet types. } - default: - // Ignore unrecognised packet types. } - return nil + // All errors that reach this point are fatal. + if err != nil { + c.isOk = false + } + + return } /* vim :set ts=4 sw=4 sts=4 noet : */ |