From 731a926172dfd043a00ec2015f07af8f4970f7f1 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Wed, 14 May 2014 06:27:41 +0000 Subject: Implement the io.WriterTo interface. --- packet.go | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) (limited to 'packet.go') diff --git a/packet.go b/packet.go index 6726ed4..8fb53d0 100644 --- a/packet.go +++ b/packet.go @@ -30,6 +30,7 @@ package obfs4 import ( "encoding/binary" "fmt" + "io" "syscall" "github.com/yawning/obfs4/framing" @@ -101,24 +102,26 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint return n, frame, err } -func (c *Obfs4Conn) consumeFramedPackets() (err error) { +func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { if !c.isOk { - return syscall.EINVAL + return n, syscall.EINVAL } var buf [consumeReadSize]byte - n, err := c.conn.Read(buf[:]) + rdLen, err := c.conn.Read(buf[:]) if err != nil { return } - c.receiveBuffer.Write(buf[:n]) + c.receiveBuffer.Write(buf[:rdLen]) for c.receiveBuffer.Len() > 0 { // Decrypt an AEAD frame. - _, pkt, err := c.decoder.Decode(&c.receiveBuffer) + // TODO: Change decode to write into packet directly + var pkt []byte + _, pkt, err = c.decoder.Decode(&c.receiveBuffer) if err == framing.ErrAgain { // The accumulated payload does not make up a full frame. - return nil + return } else if err != nil { break } else if len(pkt) < packetOverhead { @@ -138,12 +141,26 @@ func (c *Obfs4Conn) consumeFramedPackets() (err error) { switch pktType { case packetTypePayload: if payloadLen > 0 { - c.receiveDecodedBuffer.Write(payload) + if w != nil { + // c.WriteTo() skips buffering in c.receiveDecodedBuffer + wrLen, err := w.Write(payload) + n += wrLen + if wrLen < int(payloadLen) { + err = io.ErrShortWrite + break + } else if err != nil { + break + } + } else { + // c.Read() stashes decoded payload in receiveDecodedBuffer + c.receiveDecodedBuffer.Write(payload) + n += int(payloadLen) + } } case packetTypePrngSeed: // Only regenerate the distribution if we are the client. if len(payload) >= distSeedLength && !c.isServer { - c.probDist.reset(payload[:distSeedLength]) + c.lenProbDist.reset(payload[:distSeedLength]) } default: // Ignore unrecognised packet types. -- cgit v1.2.3