From 84b290a9b9f7bfc316ca659f495e84650d0ee780 Mon Sep 17 00:00:00 2001 From: Brandon Wiley Date: Wed, 17 Aug 2016 15:39:21 -0500 Subject: Implementing connection pool handling semantics specified in Pluggable Transports 2.0 Specification, Draft 1 --- modes/transparent_udp/transparent_udp.go | 226 ++++++++++++++++++++----------- 1 file changed, 146 insertions(+), 80 deletions(-) diff --git a/modes/transparent_udp/transparent_udp.go b/modes/transparent_udp/transparent_udp.go index 500cd55..b4eb987 100644 --- a/modes/transparent_udp/transparent_udp.go +++ b/modes/transparent_udp/transparent_udp.go @@ -30,14 +30,15 @@ package transparent_udp import ( - "fmt" "io" + "fmt" golog "log" "net" "net/url" "strconv" "strings" - "sync" + "bytes" + "encoding/binary" "golang.org/x/net/proxy" @@ -56,6 +57,17 @@ const ( var stateDir string +type ConnState struct { + Conn *net.Conn + Waiting bool +} + +func NewConnState() ConnState { + return ConnState{nil, true} +} + +type ConnTracker map[string]ConnState + func ClientSetup(termMon *termmon.TermMonitor, target string) bool { methodNames := [...]string{"obfs2"} var ptClientProxy *url.URL = nil @@ -79,7 +91,7 @@ func ClientSetup(termMon *termmon.TermMonitor, target string) bool { fmt.Println("Error resolving address", socksAddr) } - fmt.Println("Listening ", socksAddr) + fmt.Println("@@@ Listening ", name, socksAddr) ln, err := net.ListenUDP("udp", udpAddr) if err != nil { log.Errorf("failed to listen %s %s", name, err.Error()) @@ -94,27 +106,82 @@ func ClientSetup(termMon *termmon.TermMonitor, target string) bool { return true } -func clientHandler(target string, termMon *termmon.TermMonitor, f base.ClientFactory, conn net.Conn, proxyURI *url.URL) { +func clientHandler(target string, termMon *termmon.TermMonitor, f base.ClientFactory, conn *net.UDPConn, proxyURI *url.URL) { + var length16 uint16 + defer conn.Close() termMon.OnHandlerStart() defer termMon.OnHandlerFinish() - fmt.Println("handling...") + fmt.Println("@@@ handling...") + + tracker := make(ConnTracker) name := f.Transport().Name() fmt.Println("Transport is", name) - // Deal with arguments. - args, err := f.ParseArgs(&pt.Args{}) - if err != nil { - fmt.Println("Invalid arguments") - log.Errorf("%s(%s) - invalid arguments: %s", name, target, err) - return + buf := make([]byte, 1024) + + // Receive UDP packets and forward them over transport connections forever + for { + n, addr, err := conn.ReadFromUDP(buf) + fmt.Println("Received ",string(buf[0:n]), " from ",addr) + + if err != nil { + fmt.Println("Error: ",err) + } + + fmt.Println(tracker) + + if state, ok := tracker[addr.String()]; ok { + // There is an open transport connection, or a connection attempt is in progress. + + if state.Waiting { + // The connection attempt is in progress. + // Drop the packet. + fmt.Println("recv: waiting") + } else { + // There is an open transport connection. + // Send the packet through the transport. + fmt.Println("recv: write") + length16 = uint16(n) + lengthBuf := new(bytes.Buffer) + err = binary.Write(lengthBuf, binary.LittleEndian, length16) + if err != nil { + fmt.Println("binary.Write failed:", err) + } else { + fmt.Println("writing...") + fmt.Println(length16) + fmt.Println(lengthBuf.Bytes()) + (*state.Conn).Write(lengthBuf.Bytes()) + (*state.Conn).Write(buf) + } + } + } else { + // There is not an open transport connection and a connection attempt is not in progress. + // Open a transport connection. + + fmt.Println("Opening connection to ", target) + + openConnection(&tracker, addr.String(), target, termMon, f, proxyURI) + + // Drop the packet. + fmt.Println("recv: Open") + } } +} +func openConnection(tracker *ConnTracker, addr string, target string, termMon *termmon.TermMonitor, f base.ClientFactory, proxyURI *url.URL) { fmt.Println("Making dialer...") + newConn := NewConnState() + (*tracker)[addr]=newConn + + go dialConn(tracker, addr, target, f, proxyURI) +} + +func dialConn(tracker *ConnTracker, addr string, target string, f base.ClientFactory, proxyURI *url.URL) { // Obtain the proxy dialer if any, and create the outgoing TCP connection. dialFn := proxy.Direct.Dial if proxyURI != nil { @@ -123,41 +190,39 @@ func clientHandler(target string, termMon *termmon.TermMonitor, f base.ClientFac // This should basically never happen, since config protocol // verifies this. fmt.Println("failed to obtain dialer", proxyURI, proxy.Direct) - log.Errorf("%s(%s) - failed to obtain proxy dialer: %s", name, target, log.ElideError(err)) + log.Errorf("(%s) - failed to obtain proxy dialer: %s", target, log.ElideError(err)) return } dialFn = dialer.Dial } - fmt.Println("Dialing...") + fmt.Println("Dialing....") - remote, err := f.Dial("tcp", target, dialFn, args) + // Deal with arguments. + args, err := f.ParseArgs(&pt.Args{}) if err != nil { - fmt.Println("outgoing connection failed") - log.Errorf("%s(%s) - outgoing connection failed: %s", name, target, log.ElideError(err)) + fmt.Println("Invalid arguments") + log.Errorf("(%s) - invalid arguments: %s", target, err) + delete(*tracker, addr) return } - defer remote.Close() - fmt.Println("copying...") - - if err = copyLoopUDP(conn, remote); err != nil { - log.Warnf("%s(%s) - closed connection: %s", name, target, log.ElideError(err)) - } else { - log.Infof("%s(%s) - closed connection", name, target) + fmt.Println("Dialing ", target) + remote, err := f.Dial("tcp", target, dialFn, args) + if err != nil { + fmt.Println("outgoing connection failed", err) + log.Errorf("(%s) - outgoing connection failed: %s", target, log.ElideError(err)) + fmt.Println("Failed") + delete(*tracker, addr) + return } - fmt.Println("done") + fmt.Println("Success") - return + (*tracker)[addr]=ConnState{&remote, false} } -func ServerSetup(termMon *termmon.TermMonitor, bindaddrString string) bool { - ptServerInfo, err := pt.ServerSetup(transports.Transports()) - if err != nil { - golog.Fatal(err) - } - +func ServerSetup(termMon *termmon.TermMonitor, bindaddrString string, target string) bool { fmt.Println("ServerSetup") bindaddrs, _ := getServerBindaddrs(bindaddrString) @@ -183,7 +248,7 @@ func ServerSetup(termMon *termmon.TermMonitor, bindaddrString string) bool { continue } - go serverAcceptLoop(termMon, f, ln, &ptServerInfo) + go serverAcceptLoop(termMon, f, ln, target) log.Infof("%s - registered listener: %s", name, log.ElideAddr(ln.Addr().String())) } @@ -257,7 +322,7 @@ func parsePort(portStr string) (int, error) { return int(port), err } -func serverAcceptLoop(termMon *termmon.TermMonitor, f base.ServerFactory, ln net.Listener, info *pt.ServerInfo) error { +func serverAcceptLoop(termMon *termmon.TermMonitor, f base.ServerFactory, ln net.Listener, target string) error { defer ln.Close() for { conn, err := ln.Accept() @@ -268,76 +333,77 @@ func serverAcceptLoop(termMon *termmon.TermMonitor, f base.ServerFactory, ln net } continue } - go serverHandler(termMon, f, conn, info) + go serverHandler(termMon, f, conn, target) } } -func serverHandler(termMon *termmon.TermMonitor, f base.ServerFactory, conn net.Conn, info *pt.ServerInfo) { +func serverHandler(termMon *termmon.TermMonitor, f base.ServerFactory, conn net.Conn, target string) { + var length16 uint16 + defer conn.Close() termMon.OnHandlerStart() defer termMon.OnHandlerFinish() name := f.Transport().Name() addrStr := log.ElideAddr(conn.RemoteAddr().String()) - fmt.Println("handling", name) + fmt.Println("### handling", name) log.Infof("%s(%s) - new connection", name, addrStr) // Instantiate the server transport method and handshake. remote, err := f.WrapConn(conn) if err != nil { - fmt.Println("handshake failed") + fmt.Println("handshake failed", err) log.Warnf("%s(%s) - handshake failed: %s", name, addrStr, log.ElideError(err)) return } - // Connect to the orport. - orConn, err := pt.DialOr(info, conn.RemoteAddr().String(), name) + serverAddr, err := net.ResolveUDPAddr("udp",target) if err != nil { - fmt.Println("OR conn failed", info, conn.RemoteAddr(), name) - log.Errorf("%s(%s) - failed to connect to ORPort: %s", name, addrStr, log.ElideError(err)) - return + golog.Fatal(err) } - defer orConn.Close() - if err = copyLoopUDP(orConn, remote); err != nil { - log.Warnf("%s(%s) - closed connection: %s", name, addrStr, log.ElideError(err)) - } else { - log.Infof("%s(%s) - closed connection", name, addrStr) + localAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + golog.Fatal(err) } - return -} - -func copyLoopUDP(a net.Conn, b net.Conn) error { - // Note: b is always the pt connection. a is the UDP connection. - errChan := make(chan error, 2) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - defer b.Close() - defer a.Close() - _, err := io.Copy(b, a) - errChan <- err - }() - go func() { - defer wg.Done() - defer a.Close() - defer b.Close() - _, err := io.Copy(a, b) - errChan <- err - }() - - // Wait for both upstream and downstream to close. Since one side - // terminating closes the other, the second error in the channel will be - // something like EINVAL (though io.Copy() will swallow EOF), so only the - // first error is returned. - wg.Wait() - if len(errChan) > 0 { - return <-errChan + dest, err := net.DialUDP("udp", localAddr, serverAddr) + if err != nil { + golog.Fatal(err) } - return nil + fmt.Println("pumping") + + defer dest.Close() + + lengthBuffer := make([]byte, 2) + + for { + fmt.Println("reading...") + // Read the incoming connection into the buffer. + readLen, err := io.ReadFull(remote, lengthBuffer) + if err != nil { + fmt.Println("read error") + break + } + + fmt.Println(readLen) + + err = binary.Read(bytes.NewReader(lengthBuffer), binary.LittleEndian, &length16) + if err != nil { + fmt.Println("deserialization error") + return + } + + fmt.Println("reading data") + + readBuffer := make([]byte, length16) + readLen, err = io.ReadFull(remote, readBuffer) + if err != nil { + fmt.Println("read error") + break + } + + dest.Write(readBuffer) + } } -- cgit v1.2.3