diff options
Diffstat (limited to 'handshake_ntor.go')
-rw-r--r-- | handshake_ntor.go | 48 |
1 files changed, 38 insertions, 10 deletions
diff --git a/handshake_ntor.go b/handshake_ntor.go index b8fd222..1aa00bc 100644 --- a/handshake_ntor.go +++ b/handshake_ntor.go @@ -166,15 +166,15 @@ func (hs *clientHandshake) parseServerHandshake(resp []byte) (int, []byte, error hs.serverAuth = new(ntor.Auth) copy(hs.serverAuth.Bytes()[:], resp[ntor.RepresentativeLength:]) - // Derive the mark + // Derive the mark. hs.mac.Reset() hs.mac.Write(hs.serverRepresentative.Bytes()[:]) hs.serverMark = hs.mac.Sum(nil)[:markLength] } // Attempt to find the mark + MAC. - pos := findMark(hs.serverMark, resp, ntor.RepresentativeLength+ntor.AuthLength+serverMinPadLength, - serverMaxHandshakeLength) + pos := findMarkMac(hs.serverMark, resp, ntor.RepresentativeLength+ntor.AuthLength+serverMinPadLength, + serverMaxHandshakeLength, false) if pos == -1 { if len(resp) >= serverMaxHandshakeLength { return 0, nil, ErrInvalidHandshake @@ -240,15 +240,15 @@ func (hs *serverHandshake) parseClientHandshake(resp []byte) ([]byte, error) { hs.clientRepresentative = new(ntor.Representative) copy(hs.clientRepresentative.Bytes()[:], resp[0:ntor.RepresentativeLength]) - // Derive the mark + // Derive the mark. hs.mac.Reset() hs.mac.Write(hs.clientRepresentative.Bytes()[:]) hs.clientMark = hs.mac.Sum(nil)[:markLength] } // Attempt to find the mark + MAC. - pos := findMark(hs.clientMark, resp, ntor.RepresentativeLength+clientMinPadLength, - serverMaxHandshakeLength) + pos := findMarkMac(hs.clientMark, resp, ntor.RepresentativeLength+clientMinPadLength, + clientMaxHandshakeLength, true) if pos == -1 { if len(resp) >= clientMaxHandshakeLength { return nil, ErrInvalidHandshake @@ -351,23 +351,51 @@ func getEpochHour() int64 { return time.Now().Unix() / 3600 } -func findMark(mark, buf []byte, startPos, maxPos int) int { +func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int) { + if len(mark) != markLength { + panic(fmt.Sprintf("BUG: Invalid mark length: %d", len(mark))) + } + endPos := len(buf) + if startPos > len(buf) { + return -1 + } if endPos > maxPos { endPos = maxPos } - if startPos > len(buf) { + if endPos - startPos < markLength + macLength { return -1 } + if fromTail { + // The server can optimize the search process by only examining the + // tail of the buffer. The client can't send valid data past M_C | + // MAC_C as it does not have the server's public key yet. + pos = endPos - (markLength + macLength) + if !hmac.Equal(buf[pos:pos+markLength], mark) { + return -1 + } + + return + } + + // The client has to actually do a substring search since the server can + // and will send payload trailing the response. + // // XXX: bytes.Index() uses a naive search, which kind of sucks. - pos := bytes.Index(buf[startPos:endPos], mark) + pos = bytes.Index(buf[startPos:endPos], mark) if pos == -1 { return -1 } + // Ensure that there is enough trailing data for the MAC. + if startPos + pos + markLength + macLength > endPos { + return -1 + } + // Return the index relative to the start of the slice. - return pos + startPos + pos += startPos + return } func makePad(min, max int) ([]byte, error) { |