summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--obfs4.go145
-rw-r--r--packet.go40
2 files changed, 119 insertions, 66 deletions
diff --git a/obfs4.go b/obfs4.go
index afe8967..eadcbef 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -51,6 +51,15 @@ const (
maxCloseInterval = 60
)
+type connState int
+
+const (
+ stateInit connState = iota
+ stateEstablished
+ stateBroken
+ stateClosed
+)
+
// Obfs4Conn is the implementation of the net.Conn interface for obfs4
// connections.
type Obfs4Conn struct {
@@ -64,7 +73,7 @@ type Obfs4Conn struct {
receiveBuffer bytes.Buffer
receiveDecodedBuffer bytes.Buffer
- isOk bool
+ state connState
isServer bool
// Server side state.
@@ -111,51 +120,65 @@ func (c *Obfs4Conn) closeAfterDelay() {
}
}
-func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) error {
+func (c *Obfs4Conn) setBroken() {
+ c.state = stateBroken
+}
+
+func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) (err error) {
if c.isServer {
panic(fmt.Sprintf("BUG: clientHandshake() called for server connection"))
}
+ defer func() {
+ if err != nil {
+ c.setBroken()
+ }
+ }()
+
// Generate/send the client handshake.
- hs, err := newClientHandshake(nodeID, publicKey)
+ var hs *clientHandshake
+ var blob []byte
+ hs, err = newClientHandshake(nodeID, publicKey)
if err != nil {
- return err
+ return
}
- blob, err := hs.generateHandshake()
+ blob, err = hs.generateHandshake()
if err != nil {
- return err
+ return
}
err = c.conn.SetDeadline(time.Now().Add(connectionTimeout * 2))
if err != nil {
- return err
+ return
}
_, err = c.conn.Write(blob)
if err != nil {
- return err
+ return
}
// Consume the server handshake.
- hsBuf := make([]byte, serverMaxHandshakeLength)
+ var hsBuf [serverMaxHandshakeLength]byte
for {
- n, err := c.conn.Read(hsBuf)
+ var n int
+ n, err = c.conn.Read(hsBuf[:])
if err != nil {
- return err
+ return
}
c.receiveBuffer.Write(hsBuf[:n])
- n, seed, err := hs.parseServerHandshake(c.receiveBuffer.Bytes())
+ var seed []byte
+ n, seed, err = hs.parseServerHandshake(c.receiveBuffer.Bytes())
if err == ErrMarkNotFoundYet {
continue
} else if err != nil {
- return err
+ return
}
_ = c.receiveBuffer.Next(n)
err = c.conn.SetDeadline(time.Time{})
if err != nil {
- return err
+ return
}
// Use the derived key material to intialize the link crypto.
@@ -163,37 +186,45 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
c.encoder = framing.NewEncoder(okm[:framing.KeyLength])
c.decoder = framing.NewDecoder(okm[framing.KeyLength:])
- c.isOk = true
+ c.state = stateEstablished
return nil
}
}
-func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) error {
+func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) (err error) {
if !c.isServer {
panic(fmt.Sprintf("BUG: serverHandshake() called for client connection"))
}
+ defer func() {
+ if err != nil {
+ c.setBroken()
+ }
+ }()
+
hs := newServerHandshake(nodeID, keypair)
- err := c.conn.SetDeadline(time.Now().Add(connectionTimeout))
+ err = c.conn.SetDeadline(time.Now().Add(connectionTimeout))
if err != nil {
- return err
+ return
}
// Consume the client handshake.
- hsBuf := make([]byte, clientMaxHandshakeLength)
+ var hsBuf [clientMaxHandshakeLength]byte
for {
- n, err := c.conn.Read(hsBuf)
+ var n int
+ n, err = c.conn.Read(hsBuf[:])
if err != nil {
- return err
+ return
}
c.receiveBuffer.Write(hsBuf[:n])
- seed, err := hs.parseClientHandshake(c.receiveBuffer.Bytes())
+ var seed []byte
+ seed, err = hs.parseClientHandshake(c.receiveBuffer.Bytes())
if err == ErrMarkNotFoundYet {
continue
} else if err != nil {
- return err
+ return
}
c.receiveBuffer.Reset()
@@ -206,46 +237,51 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair)
}
// Generate/send the response.
- blob, err := hs.generateHandshake()
+ var blob []byte
+ blob, err = hs.generateHandshake()
if err != nil {
- return err
+ return
}
_, err = c.conn.Write(blob)
if err != nil {
- return err
+ return
}
err = c.conn.SetDeadline(time.Time{})
if err != nil {
- return err
+ return
}
- c.isOk = true
+ c.state = stateEstablished
// TODO: Generate/send the PRNG seed.
return nil
}
+func (c *Obfs4Conn) CanHandshake() bool {
+ return c.state == stateInit
+}
+
+func (c *Obfs4Conn) CanReadWrite() bool {
+ return c.state == stateEstablished
+}
+
func (c *Obfs4Conn) ServerHandshake() error {
// Handshakes when already established are a no-op.
- if c.isOk {
+ if c.CanReadWrite() {
return nil
+ } else if !c.CanHandshake() {
+ return syscall.EINVAL
}
- // Clients handshake as part of Dial.
if !c.isServer {
panic(fmt.Sprintf("BUG: ServerHandshake() called for client connection"))
}
- // Regardless of what happens, don't need the listener past returning from
- // this routine.
- defer func() {
- c.listener = nil
- }()
-
// Complete the handshake.
err := c.serverHandshake(c.listener.nodeID, c.listener.keyPair)
+ c.listener = nil
if err != nil {
c.closeAfterDelay()
}
@@ -254,7 +290,7 @@ func (c *Obfs4Conn) ServerHandshake() error {
}
func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
- if !c.isOk {
+ if !c.CanReadWrite() {
return 0, syscall.EINVAL
}
@@ -272,20 +308,19 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) {
}
func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
- if !c.isOk {
+ if !c.CanReadWrite() {
return 0, syscall.EINVAL
}
- wrLen := 0
-
// If there is buffered payload from earlier Read() calls, write.
+ wrLen := 0
if c.receiveDecodedBuffer.Len() > 0 {
wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes())
if err != nil {
- c.isOk = false
+ c.setBroken()
return int64(wrLen), err
} else if wrLen < int(c.receiveDecodedBuffer.Len()) {
- c.isOk = false
+ c.setBroken()
return int64(wrLen), io.ErrShortWrite
}
c.receiveDecodedBuffer.Reset()
@@ -309,6 +344,17 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
}
func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
+ if !c.CanReadWrite() {
+ return 0, syscall.EINVAL
+ }
+
+ defer func() {
+ if err != nil {
+ c.setBroken()
+ }
+ }()
+
+ // XXX: Change this to write directly to c.conn skipping frameBuf.
chopBuf := bytes.NewBuffer(b)
var payload [maxPacketPayloadLength]byte
var frameBuf bytes.Buffer
@@ -318,7 +364,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
rdLen := 0
rdLen, err = chopBuf.Read(payload[:])
if err != nil {
- c.isOk = false
return 0, err
} else if rdLen == 0 {
panic(fmt.Sprintf("BUG: Write(), chopping length was 0"))
@@ -327,7 +372,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
err = c.producePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0)
if err != nil {
- c.isOk = false
return 0, err
}
}
@@ -340,20 +384,17 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
uint16(padLen-headerLength))
if err != nil {
- c.isOk = false
return 0, err
}
} else if padLen > 0 {
err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
maxPacketPayloadLength)
if err != nil {
- c.isOk = false
return 0, err
}
err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
uint16(padLen))
if err != nil {
- c.isOk = false
return 0, err
}
}
@@ -364,7 +405,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
// Partial writes are fatal because the frame encoder state is advanced
// at this point. It's possible to keep frameBuf around, but fuck it.
// Someone that wants write timeouts can change this.
- c.isOk = false
return 0, err
}
@@ -376,13 +416,13 @@ func (c *Obfs4Conn) Close() error {
return syscall.EINVAL
}
- c.isOk = false
+ c.state = stateClosed
return c.conn.Close()
}
func (c *Obfs4Conn) LocalAddr() net.Addr {
- if !c.isOk {
+ if c.state == stateClosed {
return nil
}
@@ -390,7 +430,7 @@ func (c *Obfs4Conn) LocalAddr() net.Addr {
}
func (c *Obfs4Conn) RemoteAddr() net.Addr {
- if !c.isOk {
+ if c.state == stateClosed {
return nil
}
@@ -402,7 +442,7 @@ func (c *Obfs4Conn) SetDeadline(t time.Time) error {
}
func (c *Obfs4Conn) SetReadDeadline(t time.Time) error {
- if !c.isOk {
+ if !c.CanReadWrite() {
return syscall.EINVAL
}
@@ -487,6 +527,7 @@ func (l *Obfs4Listener) PublicKey() string {
if l.keyPair == nil {
return ""
}
+
return l.keyPair.Public().Base64()
}
diff --git a/packet.go b/packet.go
index 7b69517..339a86d 100644
--- a/packet.go
+++ b/packet.go
@@ -67,14 +67,24 @@ func (e InvalidPayloadLengthError) Error() string {
var zeroPadBytes [maxPacketPaddingLength]byte
-func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error {
+func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) (err error) {
var pkt [framing.MaximumFramePayloadLength]byte
+ if !c.CanReadWrite() {
+ return syscall.EINVAL
+ }
+
if len(data)+int(padLen) > maxPacketPayloadLength {
panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d",
len(data), padLen, maxPacketPayloadLength))
}
+ defer func() {
+ if err != nil {
+ c.setBroken()
+ }
+ }()
+
// Packets are:
// uint8_t type packetTypePayload (0x00)
// uint16_t length Length of the payload (Big Endian).
@@ -91,31 +101,32 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe
// Encode the packet in an AEAD frame.
// TODO: Change Encode to write into frame directly
- _, frame, err := c.encoder.Encode(pkt[:pktLen])
+ var frame []byte
+ _, frame, err = c.encoder.Encode(pkt[:pktLen])
if err != nil {
// All encoder errors are fatal.
- c.isOk = false
- return err
+ return
}
- wrLen, err := w.Write(frame)
+ var wrLen int
+ wrLen, err = w.Write(frame)
if err != nil {
- c.isOk = false
- return err
+ return
} else if wrLen < len(frame) {
- c.isOk = false
- return io.ErrShortWrite
+ err = io.ErrShortWrite
+ return
}
- return nil
+ return
}
func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
- if !c.isOk {
+ if !c.CanReadWrite() {
return n, syscall.EINVAL
}
var buf [consumeReadSize]byte
- rdLen, err := c.conn.Read(buf[:])
+ var rdLen int
+ rdLen, err = c.conn.Read(buf[:])
if err != nil {
return
}
@@ -150,7 +161,8 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
if payloadLen > 0 {
if w != nil {
// c.WriteTo() skips buffering in c.receiveDecodedBuffer
- wrLen, err := w.Write(payload)
+ var wrLen int
+ wrLen, err = w.Write(payload)
n += wrLen
if err != nil {
break
@@ -176,7 +188,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
// All errors that reach this point are fatal.
if err != nil {
- c.isOk = false
+ c.setBroken()
}
return