From ded3f6948cc572cd57035a093bc884b56d939463 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Wed, 14 May 2014 08:06:27 +0000 Subject: Kill Obfs4Conn.isOk with fire, and replace it with a state var. --- obfs4.go | 145 ++++++++++++++++++++++++++++++++++++++++---------------------- packet.go | 40 +++++++++++------ 2 files changed, 119 insertions(+), 66 deletions(-) diff --git a/obfs4.go b/obfs4.go index afe8967..eadcbef 100644 --- a/obfs4.go +++ b/obfs4.go @@ -51,6 +51,15 @@ const ( maxCloseInterval = 60 ) +type connState int + +const ( + stateInit connState = iota + stateEstablished + stateBroken + stateClosed +) + // Obfs4Conn is the implementation of the net.Conn interface for obfs4 // connections. type Obfs4Conn struct { @@ -64,7 +73,7 @@ type Obfs4Conn struct { receiveBuffer bytes.Buffer receiveDecodedBuffer bytes.Buffer - isOk bool + state connState isServer bool // Server side state. @@ -111,51 +120,65 @@ func (c *Obfs4Conn) closeAfterDelay() { } } -func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) error { +func (c *Obfs4Conn) setBroken() { + c.state = stateBroken +} + +func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) (err error) { if c.isServer { panic(fmt.Sprintf("BUG: clientHandshake() called for server connection")) } + defer func() { + if err != nil { + c.setBroken() + } + }() + // Generate/send the client handshake. - hs, err := newClientHandshake(nodeID, publicKey) + var hs *clientHandshake + var blob []byte + hs, err = newClientHandshake(nodeID, publicKey) if err != nil { - return err + return } - blob, err := hs.generateHandshake() + blob, err = hs.generateHandshake() if err != nil { - return err + return } err = c.conn.SetDeadline(time.Now().Add(connectionTimeout * 2)) if err != nil { - return err + return } _, err = c.conn.Write(blob) if err != nil { - return err + return } // Consume the server handshake. - hsBuf := make([]byte, serverMaxHandshakeLength) + var hsBuf [serverMaxHandshakeLength]byte for { - n, err := c.conn.Read(hsBuf) + var n int + n, err = c.conn.Read(hsBuf[:]) if err != nil { - return err + return } c.receiveBuffer.Write(hsBuf[:n]) - n, seed, err := hs.parseServerHandshake(c.receiveBuffer.Bytes()) + var seed []byte + n, seed, err = hs.parseServerHandshake(c.receiveBuffer.Bytes()) if err == ErrMarkNotFoundYet { continue } else if err != nil { - return err + return } _ = c.receiveBuffer.Next(n) err = c.conn.SetDeadline(time.Time{}) if err != nil { - return err + return } // Use the derived key material to intialize the link crypto. @@ -163,37 +186,45 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK c.encoder = framing.NewEncoder(okm[:framing.KeyLength]) c.decoder = framing.NewDecoder(okm[framing.KeyLength:]) - c.isOk = true + c.state = stateEstablished return nil } } -func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) error { +func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) (err error) { if !c.isServer { panic(fmt.Sprintf("BUG: serverHandshake() called for client connection")) } + defer func() { + if err != nil { + c.setBroken() + } + }() + hs := newServerHandshake(nodeID, keypair) - err := c.conn.SetDeadline(time.Now().Add(connectionTimeout)) + err = c.conn.SetDeadline(time.Now().Add(connectionTimeout)) if err != nil { - return err + return } // Consume the client handshake. - hsBuf := make([]byte, clientMaxHandshakeLength) + var hsBuf [clientMaxHandshakeLength]byte for { - n, err := c.conn.Read(hsBuf) + var n int + n, err = c.conn.Read(hsBuf[:]) if err != nil { - return err + return } c.receiveBuffer.Write(hsBuf[:n]) - seed, err := hs.parseClientHandshake(c.receiveBuffer.Bytes()) + var seed []byte + seed, err = hs.parseClientHandshake(c.receiveBuffer.Bytes()) if err == ErrMarkNotFoundYet { continue } else if err != nil { - return err + return } c.receiveBuffer.Reset() @@ -206,46 +237,51 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) } // Generate/send the response. - blob, err := hs.generateHandshake() + var blob []byte + blob, err = hs.generateHandshake() if err != nil { - return err + return } _, err = c.conn.Write(blob) if err != nil { - return err + return } err = c.conn.SetDeadline(time.Time{}) if err != nil { - return err + return } - c.isOk = true + c.state = stateEstablished // TODO: Generate/send the PRNG seed. return nil } +func (c *Obfs4Conn) CanHandshake() bool { + return c.state == stateInit +} + +func (c *Obfs4Conn) CanReadWrite() bool { + return c.state == stateEstablished +} + func (c *Obfs4Conn) ServerHandshake() error { // Handshakes when already established are a no-op. - if c.isOk { + if c.CanReadWrite() { return nil + } else if !c.CanHandshake() { + return syscall.EINVAL } - // Clients handshake as part of Dial. if !c.isServer { panic(fmt.Sprintf("BUG: ServerHandshake() called for client connection")) } - // Regardless of what happens, don't need the listener past returning from - // this routine. - defer func() { - c.listener = nil - }() - // Complete the handshake. err := c.serverHandshake(c.listener.nodeID, c.listener.keyPair) + c.listener = nil if err != nil { c.closeAfterDelay() } @@ -254,7 +290,7 @@ func (c *Obfs4Conn) ServerHandshake() error { } func (c *Obfs4Conn) Read(b []byte) (n int, err error) { - if !c.isOk { + if !c.CanReadWrite() { return 0, syscall.EINVAL } @@ -272,20 +308,19 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) { } func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { - if !c.isOk { + if !c.CanReadWrite() { return 0, syscall.EINVAL } - wrLen := 0 - // If there is buffered payload from earlier Read() calls, write. + wrLen := 0 if c.receiveDecodedBuffer.Len() > 0 { wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes()) if err != nil { - c.isOk = false + c.setBroken() return int64(wrLen), err } else if wrLen < int(c.receiveDecodedBuffer.Len()) { - c.isOk = false + c.setBroken() return int64(wrLen), io.ErrShortWrite } c.receiveDecodedBuffer.Reset() @@ -309,6 +344,17 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { } func (c *Obfs4Conn) Write(b []byte) (n int, err error) { + if !c.CanReadWrite() { + return 0, syscall.EINVAL + } + + defer func() { + if err != nil { + c.setBroken() + } + }() + + // XXX: Change this to write directly to c.conn skipping frameBuf. chopBuf := bytes.NewBuffer(b) var payload [maxPacketPayloadLength]byte var frameBuf bytes.Buffer @@ -318,7 +364,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { rdLen := 0 rdLen, err = chopBuf.Read(payload[:]) if err != nil { - c.isOk = false return 0, err } else if rdLen == 0 { panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) @@ -327,7 +372,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { err = c.producePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) if err != nil { - c.isOk = false return 0, err } } @@ -340,20 +384,17 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, uint16(padLen-headerLength)) if err != nil { - c.isOk = false return 0, err } } else if padLen > 0 { err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, maxPacketPayloadLength) if err != nil { - c.isOk = false return 0, err } err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, uint16(padLen)) if err != nil { - c.isOk = false return 0, err } } @@ -364,7 +405,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { // Partial writes are fatal because the frame encoder state is advanced // at this point. It's possible to keep frameBuf around, but fuck it. // Someone that wants write timeouts can change this. - c.isOk = false return 0, err } @@ -376,13 +416,13 @@ func (c *Obfs4Conn) Close() error { return syscall.EINVAL } - c.isOk = false + c.state = stateClosed return c.conn.Close() } func (c *Obfs4Conn) LocalAddr() net.Addr { - if !c.isOk { + if c.state == stateClosed { return nil } @@ -390,7 +430,7 @@ func (c *Obfs4Conn) LocalAddr() net.Addr { } func (c *Obfs4Conn) RemoteAddr() net.Addr { - if !c.isOk { + if c.state == stateClosed { return nil } @@ -402,7 +442,7 @@ func (c *Obfs4Conn) SetDeadline(t time.Time) error { } func (c *Obfs4Conn) SetReadDeadline(t time.Time) error { - if !c.isOk { + if !c.CanReadWrite() { return syscall.EINVAL } @@ -487,6 +527,7 @@ func (l *Obfs4Listener) PublicKey() string { if l.keyPair == nil { return "" } + return l.keyPair.Public().Base64() } diff --git a/packet.go b/packet.go index 7b69517..339a86d 100644 --- a/packet.go +++ b/packet.go @@ -67,14 +67,24 @@ func (e InvalidPayloadLengthError) Error() string { var zeroPadBytes [maxPacketPaddingLength]byte -func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error { +func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) (err error) { var pkt [framing.MaximumFramePayloadLength]byte + if !c.CanReadWrite() { + return syscall.EINVAL + } + if len(data)+int(padLen) > maxPacketPayloadLength { panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d", len(data), padLen, maxPacketPayloadLength)) } + defer func() { + if err != nil { + c.setBroken() + } + }() + // Packets are: // uint8_t type packetTypePayload (0x00) // uint16_t length Length of the payload (Big Endian). @@ -91,31 +101,32 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe // Encode the packet in an AEAD frame. // TODO: Change Encode to write into frame directly - _, frame, err := c.encoder.Encode(pkt[:pktLen]) + var frame []byte + _, frame, err = c.encoder.Encode(pkt[:pktLen]) if err != nil { // All encoder errors are fatal. - c.isOk = false - return err + return } - wrLen, err := w.Write(frame) + var wrLen int + wrLen, err = w.Write(frame) if err != nil { - c.isOk = false - return err + return } else if wrLen < len(frame) { - c.isOk = false - return io.ErrShortWrite + err = io.ErrShortWrite + return } - return nil + return } func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { - if !c.isOk { + if !c.CanReadWrite() { return n, syscall.EINVAL } var buf [consumeReadSize]byte - rdLen, err := c.conn.Read(buf[:]) + var rdLen int + rdLen, err = c.conn.Read(buf[:]) if err != nil { return } @@ -150,7 +161,8 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { if payloadLen > 0 { if w != nil { // c.WriteTo() skips buffering in c.receiveDecodedBuffer - wrLen, err := w.Write(payload) + var wrLen int + wrLen, err = w.Write(payload) n += wrLen if err != nil { break @@ -176,7 +188,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { // All errors that reach this point are fatal. if err != nil { - c.isOk = false + c.setBroken() } return -- cgit v1.2.3