From 5ec01f9720e4deffe463671e4578cabb7ac16b04 Mon Sep 17 00:00:00 2001 From: Ruben Pollan Date: Tue, 22 Oct 2019 13:28:04 +0200 Subject: Check if the certificate is valid obfs4 doesn't check if the cert is valid, just returns a nil transport if is invalid. - Resolves: #1 --- shapeshifter.go | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/shapeshifter.go b/shapeshifter.go index faab05e..f744808 100644 --- a/shapeshifter.go +++ b/shapeshifter.go @@ -1,12 +1,14 @@ package shapeshifter import ( + "encoding/base64" "fmt" "io" "log" "net" "sync" + "github.com/OperatorFoundation/obfs4/common/ntor" "github.com/OperatorFoundation/shapeshifter-transports/transports/obfs4" ) @@ -76,6 +78,10 @@ func (ss ShapeShifter) clientHandler(conn net.Conn) { defer conn.Close() transport := obfs4.NewObfs4Client(ss.Cert, ss.IatMode) + if transport == nil { + ss.sendError("Can not create an obfs4 client (cert: %s, iat-mode: %d)", ss.Cert, ss.IatMode) + return + } remote, err := transport.Dial(ss.Target) if err != nil { ss.sendError("outgoing connection failed %s: %v", ss.Target, err) @@ -135,10 +141,7 @@ func (ss *ShapeShifter) checkOptions() error { if ss.SocksAddr == "" { ss.SocksAddr = "127.0.0.1:0" } - if ss.Cert == "" { - return fmt.Errorf("obfs4 transport missing cert argument") - } - return nil + return isCertValid(ss.Cert) } func (ss *ShapeShifter) sendError(format string, a ...interface{}) { @@ -151,3 +154,24 @@ func (ss *ShapeShifter) sendError(format string, a ...interface{}) { log.Printf(format, a...) } } + +func isCertValid(cert string) error { + // copied from github.com/OperatorFoundation/shapeshifter-transports/transports/obfs4/statefile.go + const certSuffix = "==" + const certLength = ntor.NodeIDLength + ntor.PublicKeyLength + + if cert == "" { + return fmt.Errorf("obfs4 transport missing cert argument") + } + + decoded, err := base64.StdEncoding.DecodeString(cert + certSuffix) + if err != nil { + return fmt.Errorf("failed to decode cert: %s", err) + } + + if len(decoded) != certLength { + return fmt.Errorf("cert length %d is invalid", len(decoded)) + } + + return nil +} -- cgit v1.2.3