diff options
Diffstat (limited to 'obfsvpn_test.go')
-rw-r--r-- | obfsvpn_test.go | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/obfsvpn_test.go b/obfsvpn_test.go new file mode 100644 index 0000000..c627500 --- /dev/null +++ b/obfsvpn_test.go @@ -0,0 +1,99 @@ +package obfsvpn_test + +import ( + "context" + "io" + "testing" + "time" + + "gitlab.com/yawning/obfs4.git/common/ntor" + + "0xacab.org/leap/obfsvpn" +) + +func TestRoundTrip(t *testing.T) { + pair, err := ntor.NewKeypair(false) + if err != nil { + t.Fatalf("error generating keys: %v", err) + } + nodeID, err := ntor.NewNodeID(make([]byte, ntor.NodeIDLength)) + if err != nil { + t.Fatalf("error creating node ID: %v", err) + } + lc := obfsvpn.ListenConfig{ + NodeID: nodeID, + PrivateKey: pair.Private(), + StateDir: t.TempDir(), + } + ln, err := lc.Listen(context.Background(), "tcp", ":0") + if err != nil { + t.Fatalf("error listening for incoming connection: %v", err) + } + + const ( + clientSend = `Though they broke my legs, they gave me a crutch to walk.` + serverReply = `Her Majesty's a pretty nice girl, but she's pretty much obsolete.` + ) + + errs := make(chan error) + serverRecv := make([]byte, len(clientSend)) + go func() { + conn, err := ln.Accept() + if err != nil { + errs <- err + return + } + _, err = conn.Read(serverRecv) + if err != nil { + errs <- err + return + } + _, err = io.WriteString(conn, serverReply) + if err != nil { + errs <- err + return + } + }() + + select { + case err := <-errs: + t.Fatalf("error accepting connection: %v", err) + default: + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + d := obfsvpn.Dialer{ + NodeID: nodeID, + PublicKey: pair.Public(), + } + conn, err := d.Dial(ctx, "tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("error dialing connection: %v", err) + } + _, err = io.WriteString(conn, clientSend) + if err != nil { + t.Fatalf("error writing client side: %v", err) + } + select { + case err := <-errs: + t.Fatalf("error reading server side: %v", err) + default: + } + clientRecv := make([]byte, len(serverReply)) + _, err = conn.Read(clientRecv) + if err != nil { + t.Fatalf("error reading client side: %v", err) + } + select { + case err := <-errs: + t.Fatalf("error writing server side: %v", err) + default: + } + + if s := string(clientRecv); s != serverReply { + t.Fatalf("wrong response from server: want=%q, got=%q", serverReply, s) + } + if s := string(serverRecv); s != clientSend { + t.Fatalf("wrong request from client: want=%q, got=%q", clientSend, s) + } +} |