summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--handshake_ntor.go9
-rw-r--r--handshake_ntor_test.go24
-rw-r--r--obfs4.go14
3 files changed, 33 insertions, 14 deletions
diff --git a/handshake_ntor.go b/handshake_ntor.go
index fc107c2..92f00dc 100644
--- a/handshake_ntor.go
+++ b/handshake_ntor.go
@@ -121,14 +121,9 @@ type clientHandshake struct {
serverMark []byte
}
-func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey) (*clientHandshake, error) {
- var err error
-
+func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) {
hs := new(clientHandshake)
- hs.keypair, err = ntor.NewKeypair(true)
- if err != nil {
- return nil, err
- }
+ hs.keypair = sessionKey
hs.nodeID = nodeID
hs.serverIdentity = serverIdentity
hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength)
diff --git a/handshake_ntor_test.go b/handshake_ntor_test.go
index b3e0a4d..69fb442 100644
--- a/handshake_ntor_test.go
+++ b/handshake_ntor_test.go
@@ -43,9 +43,13 @@ func TestHandshakeNtor(t *testing.T) {
// Test client handshake padding.
for l := clientMinPadLength; l <= clientMaxPadLength; l++ {
// Generate the client state and override the pad length.
- clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
+ clientKeypair, err := ntor.NewKeypair(true)
if err != nil {
- t.Fatalf("[%d:0] newClientHandshake failed:", l, err)
+ t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
+ }
+ clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
+ if err != nil {
+ t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
}
clientHs.padLen = l
@@ -99,9 +103,13 @@ func TestHandshakeNtor(t *testing.T) {
// Test server handshake padding.
for l := serverMinPadLength; l <= serverMaxPadLength+inlineSeedFrameLength; l++ {
// Generate the client state and override the pad length.
- clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
+ clientKeypair, err := ntor.NewKeypair(true)
+ if err != nil {
+ t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
+ }
+ clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
- t.Fatalf("[%d:0] newClientHandshake failed:", l, err)
+ t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
}
clientHs.padLen = clientMinPadLength
@@ -146,9 +154,13 @@ func TestHandshakeNtor(t *testing.T) {
}
// Test oversized client padding.
- clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
+ clientKeypair, err := ntor.NewKeypair(true)
+ if err != nil {
+ t.Fatalf("ntor.NewKeypair failed: %s", err)
+ }
+ clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
- t.Fatalf("newClientHandshake failed:", err)
+ t.Fatalf("newClientHandshake failed: %s", err)
}
clientHs.padLen = clientMaxPadLength + 1
diff --git a/obfs4.go b/obfs4.go
index c780e0c..cc5e3b9 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -69,6 +69,8 @@ const (
type Obfs4Conn struct {
conn net.Conn
+ sessionKey *ntor.Keypair
+
lenProbDist *wDist
iatProbDist *wDist
@@ -157,6 +159,8 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
}
defer func() {
+ // The session key is not needed past returning from this routine.
+ c.sessionKey = nil
if err != nil {
c.setBroken()
}
@@ -165,7 +169,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
// Generate/send the client handshake.
var hs *clientHandshake
var blob []byte
- hs, err = newClientHandshake(nodeID, publicKey)
+ hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey)
if err != nil {
return
}
@@ -576,6 +580,14 @@ func DialObfs4DialFn(dialFn DialFn, network, address, nodeID, publicKey string,
}
c.iatProbDist = newWDist(iatSeed, 0, maxIatDelay)
}
+
+ // Generate the session keypair *before* connecting to the remote peer.
+ c.sessionKey, err = ntor.NewKeypair(true)
+ if err != nil {
+ return nil, err
+ }
+
+ // Connect to the remote peer.
c.conn, err = dialFn(network, address)
if err != nil {
return nil, err