diff options
-rw-r--r-- | handshake_ntor.go | 15 | ||||
-rw-r--r-- | handshake_ntor_test.go | 34 | ||||
-rw-r--r-- | obfs4.go | 20 |
3 files changed, 38 insertions, 31 deletions
diff --git a/handshake_ntor.go b/handshake_ntor.go index 92f00dc..46e2a13 100644 --- a/handshake_ntor.go +++ b/handshake_ntor.go @@ -121,7 +121,7 @@ type clientHandshake struct { serverMark []byte } -func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) { +func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) *clientHandshake { hs := new(clientHandshake) hs.keypair = sessionKey hs.nodeID = nodeID @@ -129,7 +129,7 @@ func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, ses hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength) hs.mac = hmac.New(sha256.New, append(hs.serverIdentity.Bytes()[:], hs.nodeID.Bytes()[:]...)) - return hs, nil + return hs } func (hs *clientHandshake) generateHandshake() ([]byte, error) { @@ -236,8 +236,9 @@ type serverHandshake struct { clientMark []byte } -func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair) *serverHandshake { +func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair, sessionKey *ntor.Keypair) *serverHandshake { hs := new(serverHandshake) + hs.keypair = sessionKey hs.nodeID = nodeID hs.serverIdentity = serverIdentity hs.padLen = csrand.IntRange(serverMinPadLength, serverMaxPadLength) @@ -312,14 +313,6 @@ func (hs *serverHandshake) parseClientHandshake(filter *replayFilter, resp []byt return nil, ErrInvalidHandshake } - // At this point the client knows that we exist, so do the keypair - // generation and complete our side of the handshake. - var err error - hs.keypair, err = ntor.NewKeypair(true) - if err != nil { - return nil, err - } - clientPublic := hs.clientRepresentative.ToPublic() ok, seed, auth := ntor.ServerHandshake(clientPublic, hs.keypair, hs.serverIdentity, hs.nodeID) diff --git a/handshake_ntor_test.go b/handshake_ntor_test.go index 69fb442..2f2ae2e 100644 --- a/handshake_ntor_test.go +++ b/handshake_ntor_test.go @@ -47,10 +47,7 @@ func TestHandshakeNtor(t *testing.T) { if err != nil { t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) } - clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) - if err != nil { - t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err) - } + clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) clientHs.padLen = l // Generate what the client will send to the server. @@ -69,7 +66,11 @@ func TestHandshakeNtor(t *testing.T) { } // Generate the server state and override the pad length. - serverHs := newServerHandshake(nodeID, idKeypair) + serverKeypair, err := ntor.NewKeypair(true) + if err != nil { + t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) + } + serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs.padLen = serverMinPadLength // Parse the client handshake message. @@ -107,10 +108,7 @@ func TestHandshakeNtor(t *testing.T) { if err != nil { t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) } - clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) - if err != nil { - t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err) - } + clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) clientHs.padLen = clientMinPadLength // Generate what the client will send to the server. @@ -123,7 +121,11 @@ func TestHandshakeNtor(t *testing.T) { } // Generate the server state and override the pad length. - serverHs := newServerHandshake(nodeID, idKeypair) + serverKeypair, err := ntor.NewKeypair(true) + if err != nil { + t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) + } + serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs.padLen = l // Parse the client handshake message. @@ -158,7 +160,7 @@ func TestHandshakeNtor(t *testing.T) { if err != nil { t.Fatalf("ntor.NewKeypair failed: %s", err) } - clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) + clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) if err != nil { t.Fatalf("newClientHandshake failed: %s", err) } @@ -168,7 +170,11 @@ func TestHandshakeNtor(t *testing.T) { if err != nil { t.Fatalf("clientHandshake.generateHandshake() (forced oversize) failed: %s", err) } - serverHs := newServerHandshake(nodeID, idKeypair) + serverKeypair, err := ntor.NewKeypair(true) + if err != nil { + t.Fatalf("ntor.NewKeypair failed: %s", err) + } + serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err == nil { t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)") @@ -180,7 +186,7 @@ func TestHandshakeNtor(t *testing.T) { if err != nil { t.Fatalf("clientHandshake.generateHandshake() (forced undersize) failed: %s", err) } - serverHs = newServerHandshake(nodeID, idKeypair) + serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err == nil { t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)") @@ -198,7 +204,7 @@ func TestHandshakeNtor(t *testing.T) { if err != nil { t.Fatalf("clientHandshake.generateHandshake() failed: %s", err) } - serverHs = newServerHandshake(nodeID, idKeypair) + serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs.padLen = serverMaxPadLength + inlineSeedFrameLength + 1 _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err != nil { @@ -159,7 +159,6 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK } defer func() { - // The session key is not needed past returning from this routine. c.sessionKey = nil if err != nil { c.setBroken() @@ -169,10 +168,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK // Generate/send the client handshake. var hs *clientHandshake var blob []byte - hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey) - if err != nil { - return - } + hs = newClientHandshake(nodeID, publicKey, c.sessionKey) blob, err = hs.generateHandshake() if err != nil { return @@ -231,12 +227,13 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) } defer func() { + c.sessionKey = nil if err != nil { c.setBroken() } }() - hs := newServerHandshake(nodeID, keypair) + hs := newServerHandshake(nodeID, keypair, c.sessionKey) err = c.conn.SetDeadline(time.Now().Add(connectionTimeout)) if err != nil { return @@ -645,6 +642,17 @@ func (l *Obfs4Listener) AcceptObfs4() (*Obfs4Conn, error) { // Allocate the obfs4 connection state. cObfs := new(Obfs4Conn) + + // Generate the session keypair *before* consuming data from the peer, to + // add more noise to the keypair generation time. The idea is that jitter + // here is masked by network latency (the time it takes for a server to + // accept a socket out of the backlog should not be fixed, and the client + // needs to send the public key). + cObfs.sessionKey, err = ntor.NewKeypair(true) + if err != nil { + return nil, err + } + cObfs.conn = c cObfs.isServer = true cObfs.listener = l |