From 731a926172dfd043a00ec2015f07af8f4970f7f1 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Wed, 14 May 2014 06:27:41 +0000 Subject: Implement the io.WriterTo interface. --- obfs4.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++------ packet.go | 33 +++++++++++++++++++++++++-------- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/obfs4.go b/obfs4.go index f2a0139..2823b75 100644 --- a/obfs4.go +++ b/obfs4.go @@ -31,6 +31,7 @@ package obfs4 import ( "bytes" "fmt" + "io" "net" "syscall" "time" @@ -55,7 +56,7 @@ const ( type Obfs4Conn struct { conn net.Conn - probDist *wDist + lenProbDist *wDist encoder *framing.Encoder decoder *framing.Decoder @@ -72,7 +73,7 @@ type Obfs4Conn struct { func (c *Obfs4Conn) calcPadLen(burstLen int) int { tailLen := burstLen % framing.MaximumSegmentLength - toPadTo := c.probDist.sample() + toPadTo := c.lenProbDist.sample() ret := 0 if toPadTo >= tailLen { @@ -258,8 +259,10 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) { } for c.receiveDecodedBuffer.Len() == 0 { - err = c.consumeFramedPackets() - if err != nil { + _, err = c.consumeFramedPackets(nil) + if err == framing.ErrAgain { + continue + } else if err != nil { return } } @@ -268,6 +271,43 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) { return } +func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { + if !c.isOk { + return 0, syscall.EINVAL + } + + wrLen := 0 + + // If there is buffered payload from earlier Read() calls, write. + if c.receiveDecodedBuffer.Len() > 0 { + wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes()) + if wrLen < int(c.receiveDecodedBuffer.Len()) { + c.isOk = false + return int64(wrLen), io.ErrShortWrite + } else if err != nil { + c.isOk = false + return int64(wrLen), err + } + c.receiveDecodedBuffer.Reset() + } + + for { + wrLen, err = c.consumeFramedPackets(w) + n += int64(wrLen) + if err == framing.ErrAgain { + continue + } else if err != nil { + // io.EOF is treated as not an error. + if err == io.EOF { + err = nil + } + break + } + } + + return +} + func (c *Obfs4Conn) Write(b []byte) (int, error) { chopBuf := bytes.NewBuffer(b) buf := make([]byte, maxPacketPayloadLength) @@ -394,7 +434,7 @@ func Dial(network, address, nodeID, publicKey string) (net.Conn, error) { // Connect to the peer. c := new(Obfs4Conn) - c.probDist, err = newWDist(nil, 0, framing.MaximumSegmentLength) + c.lenProbDist, err = newWDist(nil, 0, framing.MaximumSegmentLength) if err != nil { return nil, err } @@ -434,7 +474,7 @@ func (l *Obfs4Listener) Accept() (net.Conn, error) { cObfs.conn = c cObfs.isServer = true cObfs.listener = l - cObfs.probDist, err = newWDist(nil, 0, framing.MaximumSegmentLength) + cObfs.lenProbDist, err = newWDist(nil, 0, framing.MaximumSegmentLength) if err != nil { c.Close() return nil, err diff --git a/packet.go b/packet.go index 6726ed4..8fb53d0 100644 --- a/packet.go +++ b/packet.go @@ -30,6 +30,7 @@ package obfs4 import ( "encoding/binary" "fmt" + "io" "syscall" "github.com/yawning/obfs4/framing" @@ -101,24 +102,26 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint return n, frame, err } -func (c *Obfs4Conn) consumeFramedPackets() (err error) { +func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { if !c.isOk { - return syscall.EINVAL + return n, syscall.EINVAL } var buf [consumeReadSize]byte - n, err := c.conn.Read(buf[:]) + rdLen, err := c.conn.Read(buf[:]) if err != nil { return } - c.receiveBuffer.Write(buf[:n]) + c.receiveBuffer.Write(buf[:rdLen]) for c.receiveBuffer.Len() > 0 { // Decrypt an AEAD frame. - _, pkt, err := c.decoder.Decode(&c.receiveBuffer) + // TODO: Change decode to write into packet directly + var pkt []byte + _, pkt, err = c.decoder.Decode(&c.receiveBuffer) if err == framing.ErrAgain { // The accumulated payload does not make up a full frame. - return nil + return } else if err != nil { break } else if len(pkt) < packetOverhead { @@ -138,12 +141,26 @@ func (c *Obfs4Conn) consumeFramedPackets() (err error) { switch pktType { case packetTypePayload: if payloadLen > 0 { - c.receiveDecodedBuffer.Write(payload) + if w != nil { + // c.WriteTo() skips buffering in c.receiveDecodedBuffer + wrLen, err := w.Write(payload) + n += wrLen + if wrLen < int(payloadLen) { + err = io.ErrShortWrite + break + } else if err != nil { + break + } + } else { + // c.Read() stashes decoded payload in receiveDecodedBuffer + c.receiveDecodedBuffer.Write(payload) + n += int(payloadLen) + } } case packetTypePrngSeed: // Only regenerate the distribution if we are the client. if len(payload) >= distSeedLength && !c.isServer { - c.probDist.reset(payload[:distSeedLength]) + c.lenProbDist.reset(payload[:distSeedLength]) } default: // Ignore unrecognised packet types. -- cgit v1.2.3