summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--handshake_ntor.go15
-rw-r--r--handshake_ntor_test.go34
-rw-r--r--obfs4.go20
3 files changed, 38 insertions, 31 deletions
diff --git a/handshake_ntor.go b/handshake_ntor.go
index 92f00dc..46e2a13 100644
--- a/handshake_ntor.go
+++ b/handshake_ntor.go
@@ -121,7 +121,7 @@ type clientHandshake struct {
serverMark []byte
}
-func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) {
+func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) *clientHandshake {
hs := new(clientHandshake)
hs.keypair = sessionKey
hs.nodeID = nodeID
@@ -129,7 +129,7 @@ func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, ses
hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength)
hs.mac = hmac.New(sha256.New, append(hs.serverIdentity.Bytes()[:], hs.nodeID.Bytes()[:]...))
- return hs, nil
+ return hs
}
func (hs *clientHandshake) generateHandshake() ([]byte, error) {
@@ -236,8 +236,9 @@ type serverHandshake struct {
clientMark []byte
}
-func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair) *serverHandshake {
+func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair, sessionKey *ntor.Keypair) *serverHandshake {
hs := new(serverHandshake)
+ hs.keypair = sessionKey
hs.nodeID = nodeID
hs.serverIdentity = serverIdentity
hs.padLen = csrand.IntRange(serverMinPadLength, serverMaxPadLength)
@@ -312,14 +313,6 @@ func (hs *serverHandshake) parseClientHandshake(filter *replayFilter, resp []byt
return nil, ErrInvalidHandshake
}
- // At this point the client knows that we exist, so do the keypair
- // generation and complete our side of the handshake.
- var err error
- hs.keypair, err = ntor.NewKeypair(true)
- if err != nil {
- return nil, err
- }
-
clientPublic := hs.clientRepresentative.ToPublic()
ok, seed, auth := ntor.ServerHandshake(clientPublic, hs.keypair,
hs.serverIdentity, hs.nodeID)
diff --git a/handshake_ntor_test.go b/handshake_ntor_test.go
index 69fb442..2f2ae2e 100644
--- a/handshake_ntor_test.go
+++ b/handshake_ntor_test.go
@@ -47,10 +47,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
}
- clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
- if err != nil {
- t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
- }
+ clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
clientHs.padLen = l
// Generate what the client will send to the server.
@@ -69,7 +66,11 @@ func TestHandshakeNtor(t *testing.T) {
}
// Generate the server state and override the pad length.
- serverHs := newServerHandshake(nodeID, idKeypair)
+ serverKeypair, err := ntor.NewKeypair(true)
+ if err != nil {
+ t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
+ }
+ serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
serverHs.padLen = serverMinPadLength
// Parse the client handshake message.
@@ -107,10 +108,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
}
- clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
- if err != nil {
- t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
- }
+ clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
clientHs.padLen = clientMinPadLength
// Generate what the client will send to the server.
@@ -123,7 +121,11 @@ func TestHandshakeNtor(t *testing.T) {
}
// Generate the server state and override the pad length.
- serverHs := newServerHandshake(nodeID, idKeypair)
+ serverKeypair, err := ntor.NewKeypair(true)
+ if err != nil {
+ t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
+ }
+ serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
serverHs.padLen = l
// Parse the client handshake message.
@@ -158,7 +160,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil {
t.Fatalf("ntor.NewKeypair failed: %s", err)
}
- clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
+ clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
t.Fatalf("newClientHandshake failed: %s", err)
}
@@ -168,7 +170,11 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil {
t.Fatalf("clientHandshake.generateHandshake() (forced oversize) failed: %s", err)
}
- serverHs := newServerHandshake(nodeID, idKeypair)
+ serverKeypair, err := ntor.NewKeypair(true)
+ if err != nil {
+ t.Fatalf("ntor.NewKeypair failed: %s", err)
+ }
+ serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)")
@@ -180,7 +186,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil {
t.Fatalf("clientHandshake.generateHandshake() (forced undersize) failed: %s", err)
}
- serverHs = newServerHandshake(nodeID, idKeypair)
+ serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)")
@@ -198,7 +204,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil {
t.Fatalf("clientHandshake.generateHandshake() failed: %s", err)
}
- serverHs = newServerHandshake(nodeID, idKeypair)
+ serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair)
serverHs.padLen = serverMaxPadLength + inlineSeedFrameLength + 1
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err != nil {
diff --git a/obfs4.go b/obfs4.go
index cc5e3b9..b34eceb 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -159,7 +159,6 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
}
defer func() {
- // The session key is not needed past returning from this routine.
c.sessionKey = nil
if err != nil {
c.setBroken()
@@ -169,10 +168,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
// Generate/send the client handshake.
var hs *clientHandshake
var blob []byte
- hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey)
- if err != nil {
- return
- }
+ hs = newClientHandshake(nodeID, publicKey, c.sessionKey)
blob, err = hs.generateHandshake()
if err != nil {
return
@@ -231,12 +227,13 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair)
}
defer func() {
+ c.sessionKey = nil
if err != nil {
c.setBroken()
}
}()
- hs := newServerHandshake(nodeID, keypair)
+ hs := newServerHandshake(nodeID, keypair, c.sessionKey)
err = c.conn.SetDeadline(time.Now().Add(connectionTimeout))
if err != nil {
return
@@ -645,6 +642,17 @@ func (l *Obfs4Listener) AcceptObfs4() (*Obfs4Conn, error) {
// Allocate the obfs4 connection state.
cObfs := new(Obfs4Conn)
+
+ // Generate the session keypair *before* consuming data from the peer, to
+ // add more noise to the keypair generation time. The idea is that jitter
+ // here is masked by network latency (the time it takes for a server to
+ // accept a socket out of the backlog should not be fixed, and the client
+ // needs to send the public key).
+ cObfs.sessionKey, err = ntor.NewKeypair(true)
+ if err != nil {
+ return nil, err
+ }
+
cObfs.conn = c
cObfs.isServer = true
cObfs.listener = l