diff options
Diffstat (limited to 'obfs4.go')
-rw-r--r-- | obfs4.go | 145 |
1 files changed, 93 insertions, 52 deletions
@@ -51,6 +51,15 @@ const ( maxCloseInterval = 60 ) +type connState int + +const ( + stateInit connState = iota + stateEstablished + stateBroken + stateClosed +) + // Obfs4Conn is the implementation of the net.Conn interface for obfs4 // connections. type Obfs4Conn struct { @@ -64,7 +73,7 @@ type Obfs4Conn struct { receiveBuffer bytes.Buffer receiveDecodedBuffer bytes.Buffer - isOk bool + state connState isServer bool // Server side state. @@ -111,51 +120,65 @@ func (c *Obfs4Conn) closeAfterDelay() { } } -func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) error { +func (c *Obfs4Conn) setBroken() { + c.state = stateBroken +} + +func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicKey) (err error) { if c.isServer { panic(fmt.Sprintf("BUG: clientHandshake() called for server connection")) } + defer func() { + if err != nil { + c.setBroken() + } + }() + // Generate/send the client handshake. - hs, err := newClientHandshake(nodeID, publicKey) + var hs *clientHandshake + var blob []byte + hs, err = newClientHandshake(nodeID, publicKey) if err != nil { - return err + return } - blob, err := hs.generateHandshake() + blob, err = hs.generateHandshake() if err != nil { - return err + return } err = c.conn.SetDeadline(time.Now().Add(connectionTimeout * 2)) if err != nil { - return err + return } _, err = c.conn.Write(blob) if err != nil { - return err + return } // Consume the server handshake. - hsBuf := make([]byte, serverMaxHandshakeLength) + var hsBuf [serverMaxHandshakeLength]byte for { - n, err := c.conn.Read(hsBuf) + var n int + n, err = c.conn.Read(hsBuf[:]) if err != nil { - return err + return } c.receiveBuffer.Write(hsBuf[:n]) - n, seed, err := hs.parseServerHandshake(c.receiveBuffer.Bytes()) + var seed []byte + n, seed, err = hs.parseServerHandshake(c.receiveBuffer.Bytes()) if err == ErrMarkNotFoundYet { continue } else if err != nil { - return err + return } _ = c.receiveBuffer.Next(n) err = c.conn.SetDeadline(time.Time{}) if err != nil { - return err + return } // Use the derived key material to intialize the link crypto. @@ -163,37 +186,45 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK c.encoder = framing.NewEncoder(okm[:framing.KeyLength]) c.decoder = framing.NewDecoder(okm[framing.KeyLength:]) - c.isOk = true + c.state = stateEstablished return nil } } -func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) error { +func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) (err error) { if !c.isServer { panic(fmt.Sprintf("BUG: serverHandshake() called for client connection")) } + defer func() { + if err != nil { + c.setBroken() + } + }() + hs := newServerHandshake(nodeID, keypair) - err := c.conn.SetDeadline(time.Now().Add(connectionTimeout)) + err = c.conn.SetDeadline(time.Now().Add(connectionTimeout)) if err != nil { - return err + return } // Consume the client handshake. - hsBuf := make([]byte, clientMaxHandshakeLength) + var hsBuf [clientMaxHandshakeLength]byte for { - n, err := c.conn.Read(hsBuf) + var n int + n, err = c.conn.Read(hsBuf[:]) if err != nil { - return err + return } c.receiveBuffer.Write(hsBuf[:n]) - seed, err := hs.parseClientHandshake(c.receiveBuffer.Bytes()) + var seed []byte + seed, err = hs.parseClientHandshake(c.receiveBuffer.Bytes()) if err == ErrMarkNotFoundYet { continue } else if err != nil { - return err + return } c.receiveBuffer.Reset() @@ -206,46 +237,51 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) } // Generate/send the response. - blob, err := hs.generateHandshake() + var blob []byte + blob, err = hs.generateHandshake() if err != nil { - return err + return } _, err = c.conn.Write(blob) if err != nil { - return err + return } err = c.conn.SetDeadline(time.Time{}) if err != nil { - return err + return } - c.isOk = true + c.state = stateEstablished // TODO: Generate/send the PRNG seed. return nil } +func (c *Obfs4Conn) CanHandshake() bool { + return c.state == stateInit +} + +func (c *Obfs4Conn) CanReadWrite() bool { + return c.state == stateEstablished +} + func (c *Obfs4Conn) ServerHandshake() error { // Handshakes when already established are a no-op. - if c.isOk { + if c.CanReadWrite() { return nil + } else if !c.CanHandshake() { + return syscall.EINVAL } - // Clients handshake as part of Dial. if !c.isServer { panic(fmt.Sprintf("BUG: ServerHandshake() called for client connection")) } - // Regardless of what happens, don't need the listener past returning from - // this routine. - defer func() { - c.listener = nil - }() - // Complete the handshake. err := c.serverHandshake(c.listener.nodeID, c.listener.keyPair) + c.listener = nil if err != nil { c.closeAfterDelay() } @@ -254,7 +290,7 @@ func (c *Obfs4Conn) ServerHandshake() error { } func (c *Obfs4Conn) Read(b []byte) (n int, err error) { - if !c.isOk { + if !c.CanReadWrite() { return 0, syscall.EINVAL } @@ -272,20 +308,19 @@ func (c *Obfs4Conn) Read(b []byte) (n int, err error) { } func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { - if !c.isOk { + if !c.CanReadWrite() { return 0, syscall.EINVAL } - wrLen := 0 - // If there is buffered payload from earlier Read() calls, write. + wrLen := 0 if c.receiveDecodedBuffer.Len() > 0 { wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes()) if err != nil { - c.isOk = false + c.setBroken() return int64(wrLen), err } else if wrLen < int(c.receiveDecodedBuffer.Len()) { - c.isOk = false + c.setBroken() return int64(wrLen), io.ErrShortWrite } c.receiveDecodedBuffer.Reset() @@ -309,6 +344,17 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { } func (c *Obfs4Conn) Write(b []byte) (n int, err error) { + if !c.CanReadWrite() { + return 0, syscall.EINVAL + } + + defer func() { + if err != nil { + c.setBroken() + } + }() + + // XXX: Change this to write directly to c.conn skipping frameBuf. chopBuf := bytes.NewBuffer(b) var payload [maxPacketPayloadLength]byte var frameBuf bytes.Buffer @@ -318,7 +364,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { rdLen := 0 rdLen, err = chopBuf.Read(payload[:]) if err != nil { - c.isOk = false return 0, err } else if rdLen == 0 { panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) @@ -327,7 +372,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { err = c.producePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) if err != nil { - c.isOk = false return 0, err } } @@ -340,20 +384,17 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, uint16(padLen-headerLength)) if err != nil { - c.isOk = false return 0, err } } else if padLen > 0 { err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, maxPacketPayloadLength) if err != nil { - c.isOk = false return 0, err } err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, uint16(padLen)) if err != nil { - c.isOk = false return 0, err } } @@ -364,7 +405,6 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { // Partial writes are fatal because the frame encoder state is advanced // at this point. It's possible to keep frameBuf around, but fuck it. // Someone that wants write timeouts can change this. - c.isOk = false return 0, err } @@ -376,13 +416,13 @@ func (c *Obfs4Conn) Close() error { return syscall.EINVAL } - c.isOk = false + c.state = stateClosed return c.conn.Close() } func (c *Obfs4Conn) LocalAddr() net.Addr { - if !c.isOk { + if c.state == stateClosed { return nil } @@ -390,7 +430,7 @@ func (c *Obfs4Conn) LocalAddr() net.Addr { } func (c *Obfs4Conn) RemoteAddr() net.Addr { - if !c.isOk { + if c.state == stateClosed { return nil } @@ -402,7 +442,7 @@ func (c *Obfs4Conn) SetDeadline(t time.Time) error { } func (c *Obfs4Conn) SetReadDeadline(t time.Time) error { - if !c.isOk { + if !c.CanReadWrite() { return syscall.EINVAL } @@ -487,6 +527,7 @@ func (l *Obfs4Listener) PublicKey() string { if l.keyPair == nil { return "" } + return l.keyPair.Public().Base64() } |