From 398b795c87387d25c889a3bf700b387cd120520e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Sat, 21 Mar 2015 21:48:36 +0100 Subject: Simplify some err and return logic --- transports/obfs4/handshake_ntor.go | 5 ++- transports/obfs4/obfs4.go | 63 +++++++++++++++++--------------------- transports/obfs4/packet.go | 17 +++++----- transports/obfs4/statefile.go | 6 ++-- 4 files changed, 40 insertions(+), 51 deletions(-) (limited to 'transports/obfs4') diff --git a/transports/obfs4/handshake_ntor.go b/transports/obfs4/handshake_ntor.go index 8dcf0c8..57de460 100644 --- a/transports/obfs4/handshake_ntor.go +++ b/transports/obfs4/handshake_ntor.go @@ -417,10 +417,9 @@ func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int func makePad(padLen int) ([]byte, error) { pad := make([]byte, padLen) - err := csrand.Bytes(pad) - if err != nil { + if err := csrand.Bytes(pad); err != nil { return nil, err } - return pad, err + return pad, nil } diff --git a/transports/obfs4/obfs4.go b/transports/obfs4/obfs4.go index 256f549..07af9ab 100644 --- a/transports/obfs4/obfs4.go +++ b/transports/obfs4/obfs4.go @@ -105,16 +105,15 @@ func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { // ServerFactory returns a new obfs4ServerFactory instance. func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { - var err error - - var st *obfs4ServerState - if st, err = serverStateFromArgs(stateDir, args); err != nil { + st, err := serverStateFromArgs(stateDir, args) + if err != nil { return nil, err } var iatSeed *drbg.Seed if st.iatMode != iatNone { iatSeedSrc := sha256.Sum256(st.drbgSeed.Bytes()[:]) + var err error iatSeed, err = drbg.SeedFromBytes(iatSeedSrc[:]) if err != nil { return nil, err @@ -152,8 +151,6 @@ func (cf *obfs4ClientFactory) Transport() base.Transport { } func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { - var err error - var nodeID *ntor.NodeID var publicKey *ntor.PublicKey @@ -161,8 +158,8 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { // for the Node ID and Public Key. certStr, ok := args.Get(certArg) if ok { - var cert *obfs4ServerCert - if cert, err = serverCertFromString(certStr); err != nil { + cert, err := serverCertFromString(certStr) + if err != nil { return nil, err } nodeID, publicKey = cert.unpack() @@ -173,6 +170,7 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { if !ok { return nil, fmt.Errorf("missing argument '%s'", nodeIDArg) } + var err error if nodeID, err = ntor.NodeIDFromHex(nodeIDStr); err != nil { return nil, err } @@ -191,8 +189,7 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { if !ok { return nil, fmt.Errorf("missing argument '%s'", iatArg) } - var iatMode int - iatMode, err = strconv.Atoi(iatStr) + iatMode, err := strconv.Atoi(iatStr) if err != nil || iatMode < iatNone || iatMode > iatParanoid { return nil, fmt.Errorf("invalid iat-mode '%d'", iatMode) } @@ -343,16 +340,15 @@ func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *nto // Consume the server handshake. var hsBuf [maxHandshakeLength]byte for { - var n int - if n, err = conn.Conn.Read(hsBuf[:]); err != nil { + n, err := conn.Conn.Read(hsBuf[:]) + if err != nil { // The Read() could have returned data and an error, but there is // no point in continuing on an EOF or whatever. return err } conn.receiveBuffer.Write(hsBuf[:n]) - var seed []byte - n, seed, err = hs.parseServerHandshake(conn.receiveBuffer.Bytes()) + n, seed, err := hs.parseServerHandshake(conn.receiveBuffer.Bytes()) if err == ErrMarkNotFoundYet { continue } else if err != nil { @@ -369,39 +365,38 @@ func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *nto } } -func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.Keypair) (err error) { +func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.Keypair) error { if !conn.isServer { return fmt.Errorf("serverHandshake called on client connection") } // Generate the server handshake, and arm the base timeout. hs := newServerHandshake(sf.nodeID, sf.identityKey, sessionKey) - if err = conn.Conn.SetDeadline(time.Now().Add(serverHandshakeTimeout)); err != nil { - return + if err := conn.Conn.SetDeadline(time.Now().Add(serverHandshakeTimeout)); err != nil { + return err } // Consume the client handshake. var hsBuf [maxHandshakeLength]byte for { - var n int - if n, err = conn.Conn.Read(hsBuf[:]); err != nil { + n, err := conn.Conn.Read(hsBuf[:]) + if err != nil { // The Read() could have returned data and an error, but there is // no point in continuing on an EOF or whatever. - return + return err } conn.receiveBuffer.Write(hsBuf[:n]) - var seed []byte - seed, err = hs.parseClientHandshake(sf.replayFilter, conn.receiveBuffer.Bytes()) + seed, err := hs.parseClientHandshake(sf.replayFilter, conn.receiveBuffer.Bytes()) if err == ErrMarkNotFoundYet { continue } else if err != nil { - return + return err } conn.receiveBuffer.Reset() - if err = conn.Conn.SetDeadline(time.Time{}); err != nil { - return + if err := conn.Conn.SetDeadline(time.Time{}); err != nil { + return nil } // Use the derived key material to intialize the link crypto. @@ -422,26 +417,24 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor. // handshake_ntor.go. // Generate/send the response. - var blob []byte - blob, err = hs.generateHandshake() + blob, err := hs.generateHandshake() if err != nil { - return + return err } var frameBuf bytes.Buffer - _, err = frameBuf.Write(blob) - if err != nil { - return + if _, err = frameBuf.Write(blob); err != nil { + return err } // Send the PRNG seed as the first packet. - if err = conn.makePacket(&frameBuf, packetTypePrngSeed, sf.lenSeed.Bytes()[:], 0); err != nil { - return + if err := conn.makePacket(&frameBuf, packetTypePrngSeed, sf.lenSeed.Bytes()[:], 0); err != nil { + return err } if _, err = conn.Conn.Write(frameBuf.Bytes()); err != nil { - return + return err } - return + return nil } func (conn *obfs4Conn) Read(b []byte) (n int, err error) { diff --git a/transports/obfs4/packet.go b/transports/obfs4/packet.go index 9865c82..461ad54 100644 --- a/transports/obfs4/packet.go +++ b/transports/obfs4/packet.go @@ -69,7 +69,7 @@ func (e InvalidPayloadLengthError) Error() string { var zeroPadBytes [maxPacketPaddingLength]byte -func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) (err error) { +func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error { var pkt [framing.MaximumFramePayloadLength]byte if len(data)+int(padLen) > maxPacketPayloadLength { @@ -93,22 +93,19 @@ func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLe // Encode the packet in an AEAD frame. var frame [framing.MaximumSegmentLength]byte - frameLen := 0 - frameLen, err = conn.encoder.Encode(frame[:], pkt[:pktLen]) + frameLen, err := conn.encoder.Encode(frame[:], pkt[:pktLen]) if err != nil { // All encoder errors are fatal. - return + return err } - var wrLen int - wrLen, err = w.Write(frame[:frameLen]) + wrLen, err := w.Write(frame[:frameLen]) if err != nil { - return + return err } else if wrLen < frameLen { - err = io.ErrShortWrite - return + return io.ErrShortWrite } - return + return nil } func (conn *obfs4Conn) readPackets() (err error) { diff --git a/transports/obfs4/statefile.go b/transports/obfs4/statefile.go index 0838180..6c34f35 100644 --- a/transports/obfs4/statefile.go +++ b/transports/obfs4/statefile.go @@ -183,7 +183,7 @@ func jsonServerStateFromFile(stateDir string, js *jsonServerState) error { return err } - if err = json.Unmarshal(f, js); err != nil { + if err := json.Unmarshal(f, js); err != nil { return fmt.Errorf("failed to load statefile '%s': %s", fPath, err) } @@ -227,7 +227,7 @@ func newJSONServerState(stateDir string, js *jsonServerState) (err error) { return nil } -func newBridgeFile(stateDir string, st *obfs4ServerState) (err error) { +func newBridgeFile(stateDir string, st *obfs4ServerState) error { const prefix = "# obfs4 torrc client bridge line\n" + "#\n" + "# This file is an automatically generated bridge line based on\n" + @@ -244,7 +244,7 @@ func newBridgeFile(stateDir string, st *obfs4ServerState) (err error) { st.clientString()) tmp := []byte(prefix + bridgeLine) - if err = ioutil.WriteFile(path.Join(stateDir, bridgeFile), tmp, 0600); err != nil { + if err := ioutil.WriteFile(path.Join(stateDir, bridgeFile), tmp, 0600); err != nil { return err } -- cgit v1.2.3