summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--shapeshifter.go32
1 files changed, 28 insertions, 4 deletions
diff --git a/shapeshifter.go b/shapeshifter.go
index faab05e..f744808 100644
--- a/shapeshifter.go
+++ b/shapeshifter.go
@@ -1,12 +1,14 @@
package shapeshifter
import (
+ "encoding/base64"
"fmt"
"io"
"log"
"net"
"sync"
+ "github.com/OperatorFoundation/obfs4/common/ntor"
"github.com/OperatorFoundation/shapeshifter-transports/transports/obfs4"
)
@@ -76,6 +78,10 @@ func (ss ShapeShifter) clientHandler(conn net.Conn) {
defer conn.Close()
transport := obfs4.NewObfs4Client(ss.Cert, ss.IatMode)
+ if transport == nil {
+ ss.sendError("Can not create an obfs4 client (cert: %s, iat-mode: %d)", ss.Cert, ss.IatMode)
+ return
+ }
remote, err := transport.Dial(ss.Target)
if err != nil {
ss.sendError("outgoing connection failed %s: %v", ss.Target, err)
@@ -135,10 +141,7 @@ func (ss *ShapeShifter) checkOptions() error {
if ss.SocksAddr == "" {
ss.SocksAddr = "127.0.0.1:0"
}
- if ss.Cert == "" {
- return fmt.Errorf("obfs4 transport missing cert argument")
- }
- return nil
+ return isCertValid(ss.Cert)
}
func (ss *ShapeShifter) sendError(format string, a ...interface{}) {
@@ -151,3 +154,24 @@ func (ss *ShapeShifter) sendError(format string, a ...interface{}) {
log.Printf(format, a...)
}
}
+
+func isCertValid(cert string) error {
+ // copied from github.com/OperatorFoundation/shapeshifter-transports/transports/obfs4/statefile.go
+ const certSuffix = "=="
+ const certLength = ntor.NodeIDLength + ntor.PublicKeyLength
+
+ if cert == "" {
+ return fmt.Errorf("obfs4 transport missing cert argument")
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(cert + certSuffix)
+ if err != nil {
+ return fmt.Errorf("failed to decode cert: %s", err)
+ }
+
+ if len(decoded) != certLength {
+ return fmt.Errorf("cert length %d is invalid", len(decoded))
+ }
+
+ return nil
+}