summaryrefslogtreecommitdiff
path: root/packet.go
diff options
context:
space:
mode:
Diffstat (limited to 'packet.go')
-rw-r--r--packet.go40
1 files changed, 26 insertions, 14 deletions
diff --git a/packet.go b/packet.go
index 7b69517..339a86d 100644
--- a/packet.go
+++ b/packet.go
@@ -67,14 +67,24 @@ func (e InvalidPayloadLengthError) Error() string {
var zeroPadBytes [maxPacketPaddingLength]byte
-func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error {
+func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) (err error) {
var pkt [framing.MaximumFramePayloadLength]byte
+ if !c.CanReadWrite() {
+ return syscall.EINVAL
+ }
+
if len(data)+int(padLen) > maxPacketPayloadLength {
panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d",
len(data), padLen, maxPacketPayloadLength))
}
+ defer func() {
+ if err != nil {
+ c.setBroken()
+ }
+ }()
+
// Packets are:
// uint8_t type packetTypePayload (0x00)
// uint16_t length Length of the payload (Big Endian).
@@ -91,31 +101,32 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe
// Encode the packet in an AEAD frame.
// TODO: Change Encode to write into frame directly
- _, frame, err := c.encoder.Encode(pkt[:pktLen])
+ var frame []byte
+ _, frame, err = c.encoder.Encode(pkt[:pktLen])
if err != nil {
// All encoder errors are fatal.
- c.isOk = false
- return err
+ return
}
- wrLen, err := w.Write(frame)
+ var wrLen int
+ wrLen, err = w.Write(frame)
if err != nil {
- c.isOk = false
- return err
+ return
} else if wrLen < len(frame) {
- c.isOk = false
- return io.ErrShortWrite
+ err = io.ErrShortWrite
+ return
}
- return nil
+ return
}
func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
- if !c.isOk {
+ if !c.CanReadWrite() {
return n, syscall.EINVAL
}
var buf [consumeReadSize]byte
- rdLen, err := c.conn.Read(buf[:])
+ var rdLen int
+ rdLen, err = c.conn.Read(buf[:])
if err != nil {
return
}
@@ -150,7 +161,8 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
if payloadLen > 0 {
if w != nil {
// c.WriteTo() skips buffering in c.receiveDecodedBuffer
- wrLen, err := w.Write(payload)
+ var wrLen int
+ wrLen, err = w.Write(payload)
n += wrLen
if err != nil {
break
@@ -176,7 +188,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
// All errors that reach this point are fatal.
if err != nil {
- c.isOk = false
+ c.setBroken()
}
return