diff options
Diffstat (limited to 'packet.go')
-rw-r--r-- | packet.go | 40 |
1 files changed, 26 insertions, 14 deletions
@@ -67,14 +67,24 @@ func (e InvalidPayloadLengthError) Error() string { var zeroPadBytes [maxPacketPaddingLength]byte -func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error { +func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) (err error) { var pkt [framing.MaximumFramePayloadLength]byte + if !c.CanReadWrite() { + return syscall.EINVAL + } + if len(data)+int(padLen) > maxPacketPayloadLength { panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d", len(data), padLen, maxPacketPayloadLength)) } + defer func() { + if err != nil { + c.setBroken() + } + }() + // Packets are: // uint8_t type packetTypePayload (0x00) // uint16_t length Length of the payload (Big Endian). @@ -91,31 +101,32 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe // Encode the packet in an AEAD frame. // TODO: Change Encode to write into frame directly - _, frame, err := c.encoder.Encode(pkt[:pktLen]) + var frame []byte + _, frame, err = c.encoder.Encode(pkt[:pktLen]) if err != nil { // All encoder errors are fatal. - c.isOk = false - return err + return } - wrLen, err := w.Write(frame) + var wrLen int + wrLen, err = w.Write(frame) if err != nil { - c.isOk = false - return err + return } else if wrLen < len(frame) { - c.isOk = false - return io.ErrShortWrite + err = io.ErrShortWrite + return } - return nil + return } func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { - if !c.isOk { + if !c.CanReadWrite() { return n, syscall.EINVAL } var buf [consumeReadSize]byte - rdLen, err := c.conn.Read(buf[:]) + var rdLen int + rdLen, err = c.conn.Read(buf[:]) if err != nil { return } @@ -150,7 +161,8 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { if payloadLen > 0 { if w != nil { // c.WriteTo() skips buffering in c.receiveDecodedBuffer - wrLen, err := w.Write(payload) + var wrLen int + wrLen, err = w.Write(payload) n += wrLen if err != nil { break @@ -176,7 +188,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { // All errors that reach this point are fatal. if err != nil { - c.isOk = false + c.setBroken() } return |