diff options
Diffstat (limited to 'handshake_ntor.go')
-rw-r--r-- | handshake_ntor.go | 386 |
1 files changed, 386 insertions, 0 deletions
diff --git a/handshake_ntor.go b/handshake_ntor.go new file mode 100644 index 0000000..44680aa --- /dev/null +++ b/handshake_ntor.go @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2014, Yawning Angel <yawning at torproject dot org> + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +package obfs4 + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "hash" + "math/big" + "strconv" + "time" + + "github.com/yawning/obfs4/framing" + "github.com/yawning/obfs4/ntor" +) + +const ( + clientMinPadLength = serverMinHandshakeLength - clientMinHandshakeLength + clientMaxPadLength = framing.MaximumSegmentLength - clientMinHandshakeLength + clientMinHandshakeLength = ntor.RepresentativeLength + markLength + macLength + clientMaxHandshakeLength = framing.MaximumSegmentLength + + serverMinPadLength = 0 + serverMaxPadLength = framing.MaximumSegmentLength - serverMinHandshakeLength + serverMinHandshakeLength = ntor.RepresentativeLength + ntor.AuthLength + + markLength + macLength + serverMaxHandshakeLength = framing.MaximumSegmentLength + + markLength = sha256.Size + macLength = sha256.Size +) + +var ErrMarkNotFoundYet = errors.New("handshake: M_[C,S] not found yet") +var ErrInvalidHandshake = errors.New("handshake: Failed to find M_[C,S]") +var ErrNtorFailed = errors.New("handshake: ntor handshake failure") + +type InvalidMacError struct { + Derived []byte + Received []byte +} + +func (e *InvalidMacError) Error() string { + return fmt.Sprintf("handshake: MAC mismatch: Dervied: %s Received: %s.", + hex.EncodeToString(e.Derived), hex.EncodeToString(e.Received)) +} + +type InvalidAuthError struct { + Derived *ntor.Auth + Received *ntor.Auth +} + +func (e *InvalidAuthError) Error() string { + return fmt.Sprintf("handshake: ntor AUTH mismatch: Derived: %s Received:%s.", + hex.EncodeToString(e.Derived.Bytes()[:]), + hex.EncodeToString(e.Received.Bytes()[:])) +} + +type clientHandshake struct { + keypair *ntor.Keypair + nodeID *ntor.NodeID + serverIdentity *ntor.PublicKey + epochHour []byte + + mac hash.Hash + + serverRepresentative *ntor.Representative + serverAuth *ntor.Auth + serverMark []byte +} + +func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey) (*clientHandshake, error) { + var err error + + hs := new(clientHandshake) + hs.keypair, err = ntor.NewKeypair(true) + if err != nil { + return nil, err + } + hs.nodeID = nodeID + hs.serverIdentity = serverIdentity + hs.mac = hmac.New(sha256.New, hs.serverIdentity.Bytes()[:]) + + return hs, nil +} + +func (hs *clientHandshake) generateHandshake() ([]byte, error) { + var buf bytes.Buffer + + hs.mac.Reset() + hs.mac.Write(hs.keypair.Representative().Bytes()[:]) + mark := hs.mac.Sum(nil) + + // The client handshake is X | P_C | M_C | MAC(X | P_C | M_C | E) where: + // * X is the client's ephemeral Curve25519 public key representative. + // * P_C is [0,clientMaxPadLength] bytes of random padding. + // * M_C is HMAC-SHA256(serverIdentity, X) + // * MAC is HMAC-SHA256(serverIdentity, X .... E) + // * E is the string representation of the number of hours since the UNIX + // epoch. + + // Generate the padding + pad, err := makePad(clientMinPadLength, clientMaxPadLength) + if err != nil { + return nil, err + } + + // Write X, P_C, M_C. + buf.Write(hs.keypair.Representative().Bytes()[:]) + buf.Write(pad) + buf.Write(mark) + + // Calculate and write the MAC. + hs.mac.Reset() + hs.mac.Write(buf.Bytes()) + hs.epochHour = []byte(strconv.FormatInt(getEpochHour(), 10)) + hs.mac.Write(hs.epochHour) + buf.Write(hs.mac.Sum(nil)) + + return buf.Bytes(), nil +} + +func (hs *clientHandshake) parseServerHandshake(resp []byte) (int, []byte, error) { + // No point in examining the data unless the miminum plausible response has + // been received. + if serverMinHandshakeLength > len(resp) { + return 0, nil, ErrMarkNotFoundYet + } + + if hs.serverRepresentative == nil || hs.serverAuth == nil { + // Pull out the representative/AUTH. (XXX: Add ctors to ntor) + hs.serverRepresentative = new(ntor.Representative) + copy(hs.serverRepresentative.Bytes()[:], resp[0:ntor.RepresentativeLength]) + hs.serverAuth = new(ntor.Auth) + copy(hs.serverAuth.Bytes()[:], resp[ntor.RepresentativeLength:]) + + // Derive the mark + hs.mac.Reset() + hs.mac.Write(hs.serverRepresentative.Bytes()[:]) + hs.serverMark = hs.mac.Sum(nil) + } + + // Attempt to find the mark + MAC. + pos := findMark(hs.serverMark, resp, + ntor.RepresentativeLength+ntor.AuthLength, serverMaxHandshakeLength) + if pos == -1 { + if len(resp) >= serverMaxHandshakeLength { + return 0, nil, ErrInvalidHandshake + } + return 0, nil, ErrMarkNotFoundYet + } + + // Validate the MAC. + hs.mac.Reset() + hs.mac.Write(resp[:pos+markLength]) + hs.mac.Write(hs.epochHour) + macCmp := hs.mac.Sum(nil) + macRx := resp[pos+markLength : pos+markLength+macLength] + if !hmac.Equal(macCmp, macRx) { + return 0, nil, &InvalidMacError{macCmp, macRx} + } + + // Complete the handshake. + serverPublic := hs.serverRepresentative.ToPublic() + ok, seed, auth := ntor.ClientHandshake(hs.keypair, serverPublic, + hs.serverIdentity, hs.nodeID) + if !ok { + return 0, nil, ErrNtorFailed + } + if !ntor.CompareAuth(auth, hs.serverAuth.Bytes()[:]) { + return 0, nil, &InvalidAuthError{auth, hs.serverAuth} + } + + return pos + markLength + macLength, seed.Bytes()[:], nil +} + +type serverHandshake struct { + keypair *ntor.Keypair + nodeID *ntor.NodeID + serverIdentity *ntor.Keypair + epochHour []byte + serverAuth *ntor.Auth + + mac hash.Hash + + clientRepresentative *ntor.Representative + clientMark []byte +} + +func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair) *serverHandshake { + hs := new(serverHandshake) + hs.nodeID = nodeID + hs.serverIdentity = serverIdentity + hs.mac = hmac.New(sha256.New, hs.serverIdentity.Public().Bytes()[:]) + + return hs +} + +func (hs *serverHandshake) parseClientHandshake(resp []byte) ([]byte, error) { + // No point in examining the data unless the miminum plausible response has + // been received. + if clientMinHandshakeLength > len(resp) { + return nil, ErrMarkNotFoundYet + } + + if hs.clientRepresentative == nil { + // Pull out the representative/AUTH. (XXX: Add ctors to ntor) + hs.clientRepresentative = new(ntor.Representative) + copy(hs.clientRepresentative.Bytes()[:], resp[0:ntor.RepresentativeLength]) + + // Derive the mark + hs.mac.Reset() + hs.mac.Write(hs.clientRepresentative.Bytes()[:]) + hs.clientMark = hs.mac.Sum(nil) + } + + // Attempt to find the mark + MAC. + pos := findMark(hs.clientMark, resp, ntor.RepresentativeLength, + serverMaxHandshakeLength) + if pos == -1 { + if len(resp) >= clientMaxHandshakeLength { + return nil, ErrInvalidHandshake + } + return nil, ErrMarkNotFoundYet + } + + // Validate the MAC. + macFound := false + for _, off := range []int64{0, -1, 1} { + // Allow epoch to be off by up to a hour in either direction. + epochHour := []byte(strconv.FormatInt(getEpochHour()+int64(off), 10)) + hs.mac.Reset() + hs.mac.Write(resp[:pos+markLength]) + hs.mac.Write(epochHour) + macCmp := hs.mac.Sum(nil) + macRx := resp[pos+markLength : pos+markLength+macLength] + if hmac.Equal(macCmp, macRx) { + macFound = true + hs.epochHour = epochHour + + // In theory, we should always evaluate all 3 MACs, but at this + // point we are reasonably confident that the client knows the + // correct NodeID/Public key, and if this fails, we just ignore the + // client for a random interval and drop the connection anyway. + break + } + } + if !macFound { + // This probably should be an InvalidMacError, but conveying the 3 MACS + // that would be accepted is annoying so just return a generic fatal + // failure. + return nil, ErrInvalidHandshake + } + + // Client should never sent trailing garbage. + if len(resp) != pos+markLength+macLength { + return nil, ErrInvalidHandshake + } + + // At this point the client knows that we exist, so do the keypair + // generation and complete our side of the handshake. + var err error + hs.keypair, err = ntor.NewKeypair(true) + if err != nil { + return nil, err + } + + clientPublic := hs.clientRepresentative.ToPublic() + ok, seed, auth := ntor.ServerHandshake(clientPublic, hs.keypair, + hs.serverIdentity, hs.nodeID) + if !ok { + return nil, ErrNtorFailed + } + hs.serverAuth = auth + + return seed.Bytes()[:], nil +} + +func (hs *serverHandshake) generateHandshake() ([]byte, error) { + var buf bytes.Buffer + + hs.mac.Reset() + hs.mac.Write(hs.keypair.Representative().Bytes()[:]) + mark := hs.mac.Sum(nil) + + // The server handshake is Y | AUTH | P_S | M_S | MAC(Y | AUTH | P_S | M_S | E) where: + // * Y is the server's ephemeral Curve25519 public key representative. + // * AUTH is the ntor handshake AUTH value. + // * P_S is [0,serverMaxPadLength] bytes of random padding. + // * M_S is HMAC-SHA256(serverIdentity, Y) + // * MAC is HMAC-SHA256(serverIdentity, Y .... E) + // * E is the string representation of the number of hours since the UNIX + // epoch. + + // Generate the padding + pad, err := makePad(serverMinPadLength, serverMaxPadLength) + if err != nil { + return nil, err + } + + // Write Y, AUTH, P_S, M_S. + buf.Write(hs.keypair.Representative().Bytes()[:]) + buf.Write(hs.serverAuth.Bytes()[:]) + buf.Write(pad) + buf.Write(mark) + + // Calculate and write the MAC. + hs.mac.Reset() + hs.mac.Write(buf.Bytes()) + hs.epochHour = []byte(strconv.FormatInt(getEpochHour(), 10)) + hs.mac.Write(hs.epochHour) + buf.Write(hs.mac.Sum(nil)) + + return buf.Bytes(), nil +} + +// getEpochHour returns the number of hours since the UNIX epoch. +func getEpochHour() int64 { + return time.Now().Unix() / 3600 +} + +func findMark(mark, buf []byte, startPos, maxPos int) int { + endPos := len(buf) + if endPos > maxPos { + endPos = maxPos + } + + // XXX: bytes.Index() uses a naive search, which kind of sucks. + pos := bytes.Index(buf[startPos:endPos], mark) + if pos == -1 { + return -1 + } + + // Return the index relative to the start of the slice. + return pos + startPos +} + +func makePad(min, max int64) ([]byte, error) { + if max < min { + panic(fmt.Sprintf("makePad: min > max (%d, %d)", min, max)) + } + + padRange := int64((max + 1) - min) + padLen, err := rand.Int(rand.Reader, big.NewInt(padRange)) + if err != nil { + return nil, err + } + pad := make([]byte, padLen.Int64()+min) + _, err = rand.Read(pad) + if err != nil { + return nil, err + } + + return pad, err +} + +/* vim :set ts=4 sw=4 sts=4 noet : */ |