diff options
-rw-r--r-- | obfs4.go | 19 |
1 files changed, 12 insertions, 7 deletions
@@ -32,6 +32,7 @@ import ( "bytes" "fmt" "io" + "math/rand" "net" "syscall" "time" @@ -76,8 +77,6 @@ type Obfs4Conn struct { // Server side state. listener *Obfs4Listener startTime time.Time - closeDelayBytes int - closeDelay int } func (c *Obfs4Conn) padBurst(burst *bytes.Buffer) (err error) { @@ -117,7 +116,7 @@ func (c *Obfs4Conn) closeAfterDelay() { // I-it's not like I w-wanna handshake with you or anything. B-b-baka! defer c.conn.Close() - delay := time.Duration(c.closeDelay) * time.Second + delay := time.Duration(c.listener.closeDelay) * time.Second + connectionTimeout deadline := c.startTime.Add(delay) if time.Now().After(deadline) { return @@ -132,7 +131,7 @@ func (c *Obfs4Conn) closeAfterDelay() { // interval passes or a certain size has been reached. discarded := 0 var buf [framing.MaximumSegmentLength]byte - for discarded < int(c.closeDelayBytes) { + for discarded < int(c.listener.closeDelayBytes) { n, err := c.conn.Read(buf[:]) if err != nil { return @@ -325,10 +324,10 @@ func (c *Obfs4Conn) ServerHandshake() error { // Complete the handshake. err := c.serverHandshake(c.listener.nodeID, c.listener.keyPair) - c.listener = nil if err != nil { c.closeAfterDelay() } + c.listener = nil return err } @@ -524,7 +523,11 @@ type Obfs4Listener struct { keyPair *ntor.Keypair nodeID *ntor.NodeID + seed *DrbgSeed + + closeDelayBytes int + closeDelay int } func (l *Obfs4Listener) Accept() (net.Conn, error) { @@ -545,8 +548,6 @@ func (l *Obfs4Listener) Accept() (net.Conn, error) { return nil, err } cObfs.startTime = time.Now() - cObfs.closeDelayBytes = cObfs.lenProbDist.rng.Intn(maxCloseDelayBytes) - cObfs.closeDelay = cObfs.lenProbDist.rng.Intn(maxCloseDelay) return cObfs, nil } @@ -585,6 +586,10 @@ func Listen(network, laddr, nodeID, privateKey, seed string) (net.Listener, erro return nil, err } + rng := rand.New(newHashDrbg(l.seed)) + l.closeDelayBytes = rng.Intn(maxCloseDelayBytes) + l.closeDelay = rng.Intn(maxCloseDelay) + // Start up the listener. l.listener, err = net.Listen(network, laddr) if err != nil { |