summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYawning Angel <yawning@schwanenlied.me>2014-05-20 16:40:09 +0000
committerYawning Angel <yawning@schwanenlied.me>2014-05-20 16:40:09 +0000
commit66edb786dbfd69c1fd44176c81bdbfddff28059f (patch)
tree542de37746fb6b7199badef0544c8d748d5ec68b
parenta853a1f0aa10c7a096c349eb893214f8aadbf29e (diff)
Tweak the obfs4 handshake code.
* Fixed where the code wasn't ensuring that the MAC_[C,S] was present. * Optimized the server side to only look at the tail of the (possibly incomplete handshakeRequest).
-rw-r--r--handshake_ntor.go48
1 files changed, 38 insertions, 10 deletions
diff --git a/handshake_ntor.go b/handshake_ntor.go
index b8fd222..1aa00bc 100644
--- a/handshake_ntor.go
+++ b/handshake_ntor.go
@@ -166,15 +166,15 @@ func (hs *clientHandshake) parseServerHandshake(resp []byte) (int, []byte, error
hs.serverAuth = new(ntor.Auth)
copy(hs.serverAuth.Bytes()[:], resp[ntor.RepresentativeLength:])
- // Derive the mark
+ // Derive the mark.
hs.mac.Reset()
hs.mac.Write(hs.serverRepresentative.Bytes()[:])
hs.serverMark = hs.mac.Sum(nil)[:markLength]
}
// Attempt to find the mark + MAC.
- pos := findMark(hs.serverMark, resp, ntor.RepresentativeLength+ntor.AuthLength+serverMinPadLength,
- serverMaxHandshakeLength)
+ pos := findMarkMac(hs.serverMark, resp, ntor.RepresentativeLength+ntor.AuthLength+serverMinPadLength,
+ serverMaxHandshakeLength, false)
if pos == -1 {
if len(resp) >= serverMaxHandshakeLength {
return 0, nil, ErrInvalidHandshake
@@ -240,15 +240,15 @@ func (hs *serverHandshake) parseClientHandshake(resp []byte) ([]byte, error) {
hs.clientRepresentative = new(ntor.Representative)
copy(hs.clientRepresentative.Bytes()[:], resp[0:ntor.RepresentativeLength])
- // Derive the mark
+ // Derive the mark.
hs.mac.Reset()
hs.mac.Write(hs.clientRepresentative.Bytes()[:])
hs.clientMark = hs.mac.Sum(nil)[:markLength]
}
// Attempt to find the mark + MAC.
- pos := findMark(hs.clientMark, resp, ntor.RepresentativeLength+clientMinPadLength,
- serverMaxHandshakeLength)
+ pos := findMarkMac(hs.clientMark, resp, ntor.RepresentativeLength+clientMinPadLength,
+ clientMaxHandshakeLength, true)
if pos == -1 {
if len(resp) >= clientMaxHandshakeLength {
return nil, ErrInvalidHandshake
@@ -351,23 +351,51 @@ func getEpochHour() int64 {
return time.Now().Unix() / 3600
}
-func findMark(mark, buf []byte, startPos, maxPos int) int {
+func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int) {
+ if len(mark) != markLength {
+ panic(fmt.Sprintf("BUG: Invalid mark length: %d", len(mark)))
+ }
+
endPos := len(buf)
+ if startPos > len(buf) {
+ return -1
+ }
if endPos > maxPos {
endPos = maxPos
}
- if startPos > len(buf) {
+ if endPos - startPos < markLength + macLength {
return -1
}
+ if fromTail {
+ // The server can optimize the search process by only examining the
+ // tail of the buffer. The client can't send valid data past M_C |
+ // MAC_C as it does not have the server's public key yet.
+ pos = endPos - (markLength + macLength)
+ if !hmac.Equal(buf[pos:pos+markLength], mark) {
+ return -1
+ }
+
+ return
+ }
+
+ // The client has to actually do a substring search since the server can
+ // and will send payload trailing the response.
+ //
// XXX: bytes.Index() uses a naive search, which kind of sucks.
- pos := bytes.Index(buf[startPos:endPos], mark)
+ pos = bytes.Index(buf[startPos:endPos], mark)
if pos == -1 {
return -1
}
+ // Ensure that there is enough trailing data for the MAC.
+ if startPos + pos + markLength + macLength > endPos {
+ return -1
+ }
+
// Return the index relative to the start of the slice.
- return pos + startPos
+ pos += startPos
+ return
}
func makePad(min, max int) ([]byte, error) {