summaryrefslogtreecommitdiff
path: root/packet.go
diff options
context:
space:
mode:
authorYawning Angel <yawning@schwanenlied.me>2014-05-14 06:27:41 +0000
committerYawning Angel <yawning@schwanenlied.me>2014-05-14 06:27:41 +0000
commit731a926172dfd043a00ec2015f07af8f4970f7f1 (patch)
tree70c131ce0b5b25a013de6d2f8b3b671727cd9825 /packet.go
parent582aa3a366cb06bf793440f62fcf3d54b85d8270 (diff)
Implement the io.WriterTo interface.
Diffstat (limited to 'packet.go')
-rw-r--r--packet.go33
1 files changed, 25 insertions, 8 deletions
diff --git a/packet.go b/packet.go
index 6726ed4..8fb53d0 100644
--- a/packet.go
+++ b/packet.go
@@ -30,6 +30,7 @@ package obfs4
import (
"encoding/binary"
"fmt"
+ "io"
"syscall"
"github.com/yawning/obfs4/framing"
@@ -101,24 +102,26 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint
return n, frame, err
}
-func (c *Obfs4Conn) consumeFramedPackets() (err error) {
+func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
if !c.isOk {
- return syscall.EINVAL
+ return n, syscall.EINVAL
}
var buf [consumeReadSize]byte
- n, err := c.conn.Read(buf[:])
+ rdLen, err := c.conn.Read(buf[:])
if err != nil {
return
}
- c.receiveBuffer.Write(buf[:n])
+ c.receiveBuffer.Write(buf[:rdLen])
for c.receiveBuffer.Len() > 0 {
// Decrypt an AEAD frame.
- _, pkt, err := c.decoder.Decode(&c.receiveBuffer)
+ // TODO: Change decode to write into packet directly
+ var pkt []byte
+ _, pkt, err = c.decoder.Decode(&c.receiveBuffer)
if err == framing.ErrAgain {
// The accumulated payload does not make up a full frame.
- return nil
+ return
} else if err != nil {
break
} else if len(pkt) < packetOverhead {
@@ -138,12 +141,26 @@ func (c *Obfs4Conn) consumeFramedPackets() (err error) {
switch pktType {
case packetTypePayload:
if payloadLen > 0 {
- c.receiveDecodedBuffer.Write(payload)
+ if w != nil {
+ // c.WriteTo() skips buffering in c.receiveDecodedBuffer
+ wrLen, err := w.Write(payload)
+ n += wrLen
+ if wrLen < int(payloadLen) {
+ err = io.ErrShortWrite
+ break
+ } else if err != nil {
+ break
+ }
+ } else {
+ // c.Read() stashes decoded payload in receiveDecodedBuffer
+ c.receiveDecodedBuffer.Write(payload)
+ n += int(payloadLen)
+ }
}
case packetTypePrngSeed:
// Only regenerate the distribution if we are the client.
if len(payload) >= distSeedLength && !c.isServer {
- c.probDist.reset(payload[:distSeedLength])
+ c.lenProbDist.reset(payload[:distSeedLength])
}
default:
// Ignore unrecognised packet types.