summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuben Pollan <meskio@sindominio.net>2019-10-22 13:28:04 +0200
committerRuben Pollan <meskio@sindominio.net>2019-10-22 13:28:04 +0200
commit5ec01f9720e4deffe463671e4578cabb7ac16b04 (patch)
tree041059e0d0a3a224fe44b31b2eec24415c7357a2
parentf602ba600f5d3b9444b4072b7cd0b27df14be8b8 (diff)
Check if the certificate is valid
obfs4 doesn't check if the cert is valid, just returns a nil transport if is invalid. - Resolves: #1
-rw-r--r--shapeshifter.go32
1 files changed, 28 insertions, 4 deletions
diff --git a/shapeshifter.go b/shapeshifter.go
index faab05e..f744808 100644
--- a/shapeshifter.go
+++ b/shapeshifter.go
@@ -1,12 +1,14 @@
package shapeshifter
import (
+ "encoding/base64"
"fmt"
"io"
"log"
"net"
"sync"
+ "github.com/OperatorFoundation/obfs4/common/ntor"
"github.com/OperatorFoundation/shapeshifter-transports/transports/obfs4"
)
@@ -76,6 +78,10 @@ func (ss ShapeShifter) clientHandler(conn net.Conn) {
defer conn.Close()
transport := obfs4.NewObfs4Client(ss.Cert, ss.IatMode)
+ if transport == nil {
+ ss.sendError("Can not create an obfs4 client (cert: %s, iat-mode: %d)", ss.Cert, ss.IatMode)
+ return
+ }
remote, err := transport.Dial(ss.Target)
if err != nil {
ss.sendError("outgoing connection failed %s: %v", ss.Target, err)
@@ -135,10 +141,7 @@ func (ss *ShapeShifter) checkOptions() error {
if ss.SocksAddr == "" {
ss.SocksAddr = "127.0.0.1:0"
}
- if ss.Cert == "" {
- return fmt.Errorf("obfs4 transport missing cert argument")
- }
- return nil
+ return isCertValid(ss.Cert)
}
func (ss *ShapeShifter) sendError(format string, a ...interface{}) {
@@ -151,3 +154,24 @@ func (ss *ShapeShifter) sendError(format string, a ...interface{}) {
log.Printf(format, a...)
}
}
+
+func isCertValid(cert string) error {
+ // copied from github.com/OperatorFoundation/shapeshifter-transports/transports/obfs4/statefile.go
+ const certSuffix = "=="
+ const certLength = ntor.NodeIDLength + ntor.PublicKeyLength
+
+ if cert == "" {
+ return fmt.Errorf("obfs4 transport missing cert argument")
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(cert + certSuffix)
+ if err != nil {
+ return fmt.Errorf("failed to decode cert: %s", err)
+ }
+
+ if len(decoded) != certLength {
+ return fmt.Errorf("cert length %d is invalid", len(decoded))
+ }
+
+ return nil
+}