summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYawning Angel <yawning@schwanenlied.me>2014-05-14 06:27:41 +0000
committerYawning Angel <yawning@schwanenlied.me>2014-05-14 06:27:41 +0000
commit731a926172dfd043a00ec2015f07af8f4970f7f1 (patch)
tree70c131ce0b5b25a013de6d2f8b3b671727cd9825
parent582aa3a366cb06bf793440f62fcf3d54b85d8270 (diff)
Implement the io.WriterTo interface.
-rw-r--r--obfs4.go52
-rw-r--r--packet.go33
2 files changed, 71 insertions, 14 deletions
diff --git a/obfs4.go b/obfs4.go
index f2a0139..2823b75 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -31,6 +31,7 @@ package obfs4
import (
"bytes"
"fmt"
+ "io"
"net"
"syscall"
"time"
@@ -55,7 +56,7 @@ const (
type Obfs4Conn struct {
conn net.Conn
- probDist *wDist
+ lenProbDist *wDist
encoder *framing.Encoder
decoder *framing.Decoder
@@ -72,7 +73,7 @@ type Obfs4Conn struct {
func (c *Obfs4Conn) calcPadLen(burstLen int) int {
tailLen := burstLen % framing.MaximumSegmentLength
- toPadTo := c.probDist.sample()
+ toPadTo := c.lenProbDist.sample()
ret := 0
if toPadTo >= tailLen {
@@ -258,8 +259,10 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
}
for c.receiveDecodedBuffer.Len() == 0 {
- err = c.consumeFramedPackets()
- if err != nil {
+ _, err = c.consumeFramedPackets(nil)
+ if err == framing.ErrAgain {
+ continue
+ } else if err != nil {
return
}
}
@@ -268,6 +271,43 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
return
}
+func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
+ if !c.isOk {
+ return 0, syscall.EINVAL
+ }
+
+ wrLen := 0
+
+ // If there is buffered payload from earlier Read() calls, write.
+ if c.receiveDecodedBuffer.Len() > 0 {
+ wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes())
+ if wrLen < int(c.receiveDecodedBuffer.Len()) {
+ c.isOk = false
+ return int64(wrLen), io.ErrShortWrite
+ } else if err != nil {
+ c.isOk = false
+ return int64(wrLen), err
+ }
+ c.receiveDecodedBuffer.Reset()
+ }
+
+ for {
+ wrLen, err = c.consumeFramedPackets(w)
+ n += int64(wrLen)
+ if err == framing.ErrAgain {
+ continue
+ } else if err != nil {
+ // io.EOF is treated as not an error.
+ if err == io.EOF {
+ err = nil
+ }
+ break
+ }
+ }
+
+ return
+}
+
func (c *Obfs4Conn) Write(b []byte) (int, error) {
chopBuf := bytes.NewBuffer(b)
buf := make([]byte, maxPacketPayloadLength)
@@ -394,7 +434,7 @@ func Dial(network, address, nodeID, publicKey string) (net.Conn, error) {
// Connect to the peer.
c := new(Obfs4Conn)
- c.probDist, err = newWDist(nil, 0, framing.MaximumSegmentLength)
+ c.lenProbDist, err = newWDist(nil, 0, framing.MaximumSegmentLength)
if err != nil {
return nil, err
}
@@ -434,7 +474,7 @@ func (l *Obfs4Listener) Accept() (net.Conn, error) {
cObfs.conn = c
cObfs.isServer = true
cObfs.listener = l
- cObfs.probDist, err = newWDist(nil, 0, framing.MaximumSegmentLength)
+ cObfs.lenProbDist, err = newWDist(nil, 0, framing.MaximumSegmentLength)
if err != nil {
c.Close()
return nil, err
diff --git a/packet.go b/packet.go
index 6726ed4..8fb53d0 100644
--- a/packet.go
+++ b/packet.go
@@ -30,6 +30,7 @@ package obfs4
import (
"encoding/binary"
"fmt"
+ "io"
"syscall"
"github.com/yawning/obfs4/framing"
@@ -101,24 +102,26 @@ func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint
return n, frame, err
}
-func (c *Obfs4Conn) consumeFramedPackets() (err error) {
+func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
if !c.isOk {
- return syscall.EINVAL
+ return n, syscall.EINVAL
}
var buf [consumeReadSize]byte
- n, err := c.conn.Read(buf[:])
+ rdLen, err := c.conn.Read(buf[:])
if err != nil {
return
}
- c.receiveBuffer.Write(buf[:n])
+ c.receiveBuffer.Write(buf[:rdLen])
for c.receiveBuffer.Len() > 0 {
// Decrypt an AEAD frame.
- _, pkt, err := c.decoder.Decode(&c.receiveBuffer)
+ // TODO: Change decode to write into packet directly
+ var pkt []byte
+ _, pkt, err = c.decoder.Decode(&c.receiveBuffer)
if err == framing.ErrAgain {
// The accumulated payload does not make up a full frame.
- return nil
+ return
} else if err != nil {
break
} else if len(pkt) < packetOverhead {
@@ -138,12 +141,26 @@ func (c *Obfs4Conn) consumeFramedPackets() (err error) {
switch pktType {
case packetTypePayload:
if payloadLen > 0 {
- c.receiveDecodedBuffer.Write(payload)
+ if w != nil {
+ // c.WriteTo() skips buffering in c.receiveDecodedBuffer
+ wrLen, err := w.Write(payload)
+ n += wrLen
+ if wrLen < int(payloadLen) {
+ err = io.ErrShortWrite
+ break
+ } else if err != nil {
+ break
+ }
+ } else {
+ // c.Read() stashes decoded payload in receiveDecodedBuffer
+ c.receiveDecodedBuffer.Write(payload)
+ n += int(payloadLen)
+ }
}
case packetTypePrngSeed:
// Only regenerate the distribution if we are the client.
if len(payload) >= distSeedLength && !c.isServer {
- c.probDist.reset(payload[:distSeedLength])
+ c.lenProbDist.reset(payload[:distSeedLength])
}
default:
// Ignore unrecognised packet types.