summaryrefslogtreecommitdiff
path: root/obfs4proxy/obfs4proxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'obfs4proxy/obfs4proxy.go')
-rw-r--r--obfs4proxy/obfs4proxy.go72
1 files changed, 21 insertions, 51 deletions
diff --git a/obfs4proxy/obfs4proxy.go b/obfs4proxy/obfs4proxy.go
index b27d75d..9b452ac 100644
--- a/obfs4proxy/obfs4proxy.go
+++ b/obfs4proxy/obfs4proxy.go
@@ -38,7 +38,6 @@ import (
"net"
"net/url"
"os"
- "os/signal"
"path"
"sync"
"syscall"
@@ -60,7 +59,7 @@ const (
var enableLogging bool
var unsafeLogging bool
var stateDir string
-var handlerChan chan int
+var termMon *termMonitor
// DialFn is a function pointer to a function that matches the net.Dialer.Dial
// interface.
@@ -176,10 +175,8 @@ func clientAcceptLoop(f base.ClientFactory, ln *pt.SocksListener, proxyURI *url.
func clientHandler(f base.ClientFactory, conn *pt.SocksConn, proxyURI *url.URL) {
defer conn.Close()
- handlerChan <- 1
- defer func() {
- handlerChan <- -1
- }()
+ termMon.onHandlerStart()
+ defer termMon.onHandlerFinish()
name := f.Transport().Name()
addrStr := elideAddr(conn.Req.Target)
@@ -298,10 +295,8 @@ func serverAcceptLoop(f base.ServerFactory, ln net.Listener, info *pt.ServerInfo
func serverHandler(f base.ServerFactory, conn net.Conn, info *pt.ServerInfo) {
defer conn.Close()
- handlerChan <- 1
- defer func() {
- handlerChan <- -1
- }()
+ termMon.onHandlerStart()
+ defer termMon.onHandlerFinish()
name := f.Transport().Name()
addrStr := elideAddr(conn.RemoteAddr().String())
@@ -386,8 +381,8 @@ func getVersion() string {
}
func main() {
- // Initialize parent process monitoring as early as possible.
- pmonErr := initParentMonitor()
+ // Initialize the termination state monitor as soon as possible.
+ termMon = newTermMonitor()
// Handle the command line arguments.
_, execName := path.Split(os.Args[0])
@@ -405,10 +400,8 @@ func main() {
log.Fatalf("[ERROR]: failed to set log level: %s", err)
}
- // Determine if this is a client or server, initialize logging, and finish
- // the pt configuration.
+ // Determine if this is a client or server, initialize the common state.
var ptListeners []net.Listener
- handlerChan = make(chan int)
launched := false
isClient, err := ptIsClient()
if err != nil {
@@ -419,12 +412,10 @@ func main() {
}
if err = ptInitializeLogging(enableLogging); err != nil {
log.Fatalf("[ERROR]: %s - failed to initialize logging", execName)
- } else {
- noticef("%s - launched", getVersion())
- if pmonErr != nil {
- warnf("%s - failed to initialize parent monitor: %s", execName, pmonErr)
- }
}
+ noticef("%s - launched", getVersion())
+
+ // Do the managed pluggable transport protocol configuration.
if isClient {
infof("%s - initializing client transport listeners", execName)
launched, ptListeners = clientSetup()
@@ -444,39 +435,18 @@ func main() {
}()
// At this point, the pt config protocol is finished, and incoming
- // connections will be processed. Per the pt spec, on sane platforms
- // termination is signaled via SIGINT (or SIGTERM), so wait on tor to
- // request a shutdown of some sort.
-
- sigChan := make(chan os.Signal, 1)
- signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-
- // Wait for the first SIGINT (close listeners).
- var sig os.Signal
- numHandlers := 0
- for sig == nil {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- if sig == syscall.SIGTERM {
- // SIGTERM causes immediate termination.
- return
- }
- }
+ // connections will be processed. Wait till the parent dies
+ // (immediate exit), a SIGTERM is received (immediate exit),
+ // or a SIGINT is received.
+ if sig := termMon.wait(false); sig == syscall.SIGTERM {
+ return
}
+
+ // Ok, it was the first SIGINT, close all listeners, and wait till,
+ // the parent dies, all the current connections are closed, or either
+ // a SIGINT/SIGTERM is received, and exit.
for _, ln := range ptListeners {
ln.Close()
}
-
- // Wait for the 2nd SIGINT (or a SIGTERM), or for all current sessions to
- // finish.
- sig = nil
- for sig == nil && numHandlers != 0 {
- select {
- case n := <-handlerChan:
- numHandlers += n
- case sig = <-sigChan:
- }
- }
+ termMon.wait(true)
}