summaryrefslogtreecommitdiff
path: root/transports
diff options
context:
space:
mode:
Diffstat (limited to 'transports')
-rw-r--r--transports/base/base.go10
-rw-r--r--transports/obfs2/obfs2.go12
-rw-r--r--transports/obfs3/obfs3.go12
-rw-r--r--transports/obfs4/obfs4.go14
-rw-r--r--transports/scramblesuit/base.go14
5 files changed, 49 insertions, 13 deletions
diff --git a/transports/base/base.go b/transports/base/base.go
index e81ea03..bb0902e 100644
--- a/transports/base/base.go
+++ b/transports/base/base.go
@@ -35,6 +35,8 @@ import (
"git.torproject.org/pluggable-transports/goptlib.git"
)
+type DialFunc func(string, string) (net.Conn, error)
+
// ClientFactory is the interface that defines the factory for creating
// pluggable transport protocol client instances.
type ClientFactory interface {
@@ -48,10 +50,10 @@ type ClientFactory interface {
// generation) to be hidden from third parties.
ParseArgs(args *pt.Args) (interface{}, error)
- // WrapConn wraps the provided net.Conn with a transport protocol
- // implementation, and does whatever is required (eg: handshaking) to get
- // the connection to a point where it is ready to relay data.
- WrapConn(conn net.Conn, args interface{}) (net.Conn, error)
+ // Dial creates an outbound net.Conn, and does whatever is required
+ // (eg: handshaking) to get the connection to the point where it is
+ // ready to relay data.
+ Dial(network, address string, dialFn DialFunc, args interface{}) (net.Conn, error)
}
// ServerFactory is the interface that defines the factory for creating
diff --git a/transports/obfs2/obfs2.go b/transports/obfs2/obfs2.go
index bc2532b..a926141 100644
--- a/transports/obfs2/obfs2.go
+++ b/transports/obfs2/obfs2.go
@@ -108,8 +108,16 @@ func (cf *obfs2ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) {
return nil, validateArgs(args)
}
-func (cf *obfs2ClientFactory) WrapConn(conn net.Conn, args interface{}) (net.Conn, error) {
- return newObfs2ClientConn(conn)
+func (cf *obfs2ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) {
+ conn, err := dialFn(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ if conn, err = newObfs2ClientConn(conn); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ return conn, nil
}
type obfs2ServerFactory struct {
diff --git a/transports/obfs3/obfs3.go b/transports/obfs3/obfs3.go
index d215c49..e4c3ba6 100644
--- a/transports/obfs3/obfs3.go
+++ b/transports/obfs3/obfs3.go
@@ -92,8 +92,16 @@ func (cf *obfs3ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) {
return nil, nil
}
-func (cf *obfs3ClientFactory) WrapConn(conn net.Conn, args interface{}) (net.Conn, error) {
- return newObfs3ClientConn(conn)
+func (cf *obfs3ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) {
+ conn, err := dialFn(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ if conn, err = newObfs3ClientConn(conn); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ return conn, nil
}
type obfs3ServerFactory struct {
diff --git a/transports/obfs4/obfs4.go b/transports/obfs4/obfs4.go
index 07af9ab..5701535 100644
--- a/transports/obfs4/obfs4.go
+++ b/transports/obfs4/obfs4.go
@@ -204,13 +204,21 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) {
return &obfs4ClientArgs{nodeID, publicKey, sessionKey, iatMode}, nil
}
-func (cf *obfs4ClientFactory) WrapConn(conn net.Conn, args interface{}) (net.Conn, error) {
+func (cf *obfs4ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) {
+ // Validate args before bothering to open connection.
ca, ok := args.(*obfs4ClientArgs)
if !ok {
return nil, fmt.Errorf("invalid argument type for args")
}
-
- return newObfs4ClientConn(conn, ca)
+ conn, err := dialFn(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ if conn, err = newObfs4ClientConn(conn, ca); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ return conn, nil
}
type obfs4ServerFactory struct {
diff --git a/transports/scramblesuit/base.go b/transports/scramblesuit/base.go
index 711c046..223d085 100644
--- a/transports/scramblesuit/base.go
+++ b/transports/scramblesuit/base.go
@@ -76,12 +76,22 @@ func (cf *ssClientFactory) ParseArgs(args *pt.Args) (interface{}, error) {
return newClientArgs(args)
}
-func (cf *ssClientFactory) WrapConn(conn net.Conn, args interface{}) (net.Conn, error) {
+func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) {
+ // Validate args before opening outgoing connection.
ca, ok := args.(*ssClientArgs)
if !ok {
return nil, fmt.Errorf("invalid argument type for args")
}
- return newScrambleSuitClientConn(conn, cf.ticketStore, ca)
+
+ conn, err := dialFn(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ if conn, err = newScrambleSuitClientConn(conn, cf.ticketStore, ca); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ return conn, nil
}
var _ base.ClientFactory = (*ssClientFactory)(nil)