diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go index ff6552e2d8..fdf0d9a37f 100644 --- a/tailnet/test/integration/integration.go +++ b/tailnet/test/integration/integration.go @@ -11,6 +11,7 @@ import ( "net/netip" "net/url" "strconv" + "strings" "sync/atomic" "testing" "time" @@ -39,8 +40,21 @@ var ( Client2ID = uuid.MustParse("00000000-0000-0000-0000-000000000002") ) -// StartServerBasic creates a coordinator and DERP server. -func StartServerBasic(t *testing.T, logger slog.Logger, listenAddr string) { +type ServerOptions struct { + // FailUpgradeDERP will make the DERP server fail to handle the initial DERP + // upgrade in a way that causes the client to fallback to + // DERP-over-WebSocket fallback automatically. + // Incompatible with DERPWebsocketOnly. + FailUpgradeDERP bool + // DERPWebsocketOnly will make the DERP server only accept WebSocket + // connections. If a DERP request is received that is not using WebSocket + // fallback, the test will fail. + // Incompatible with FailUpgradeDERP. + DERPWebsocketOnly bool +} + +//nolint:revive +func (o ServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { coord := tailnet.NewCoordinator(logger) var coordPtr atomic.Pointer[tailnet.Coordinator] coordPtr.Store(&coord) @@ -69,15 +83,38 @@ func StartServerBasic(t *testing.T, logger slog.Logger, listenAddr string) { tracing.StatusWriterMiddleware, httpmw.Logger(logger), ) + r.Route("/derp", func(r chi.Router) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { logger.Info(r.Context(), "start derp request", slog.F("path", r.URL.Path), slog.F("remote_ip", r.RemoteAddr)) + + upgrade := strings.ToLower(r.Header.Get("Upgrade")) + if upgrade != "derp" && upgrade != "websocket" { + http.Error(w, "invalid DERP upgrade header", http.StatusBadRequest) + t.Errorf("invalid DERP upgrade header: %s", upgrade) + return + } + + if o.FailUpgradeDERP && upgrade == "derp" { + // 4xx status codes will cause the client to fallback to + // DERP-over-WebSocket. + http.Error(w, "test derp upgrade failure", http.StatusBadRequest) + return + } + if o.DERPWebsocketOnly && upgrade != "websocket" { + logger.Error(r.Context(), "non-websocket DERP request received", slog.F("path", r.URL.Path), slog.F("remote_ip", r.RemoteAddr)) + http.Error(w, "non-websocket DERP request received", http.StatusBadRequest) + t.Error("non-websocket DERP request received") + return + } + derpHandler.ServeHTTP(w, r) }) r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) }) + r.Get("/api/v2/workspaceagents/{id}/coordinate", func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() idStr := chi.URLParamFromCtx(ctx, "id") @@ -116,28 +153,44 @@ func StartServerBasic(t *testing.T, logger slog.Logger, listenAddr string) { } }) - // We have a custom listen address. - srv := http.Server{ - Addr: listenAddr, - Handler: r, - ReadTimeout: 10 * time.Second, - } - serveDone := make(chan struct{}) - go func() { - defer close(serveDone) - err := srv.ListenAndServe() - if err != nil && !xerrors.Is(err, http.ErrServerClosed) { - t.Error("HTTP server error:", err) - } - }() - t.Cleanup(func() { - _ = srv.Close() - <-serveDone + return r +} + +// StartClientDERP creates a client connection to the server for coordination +// and creates a tailnet.Conn which will only use DERP to connect to the peer. +func StartClientDERP(t *testing.T, logger slog.Logger, serverURL *url.URL, myID, peerID uuid.UUID) *tailnet.Conn { + return startClientOptions(t, logger, serverURL, myID, peerID, &tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IPFromUUID(myID), 128)}, + DERPMap: basicDERPMap(t, serverURL), + BlockEndpoints: true, + Logger: logger, + DERPForceWebSockets: false, + // These tests don't have internet connection, so we need to force + // magicsock to do anything. + ForceNetworkUp: true, }) } -// StartClientBasic creates a client connection to the server. -func StartClientBasic(t *testing.T, logger slog.Logger, serverURL *url.URL, myID uuid.UUID, peerID uuid.UUID) *tailnet.Conn { +// StartClientDERPWebSockets does the same thing as StartClientDERP but will +// only use DERP WebSocket fallback. +func StartClientDERPWebSockets(t *testing.T, logger slog.Logger, serverURL *url.URL, myID, peerID uuid.UUID) *tailnet.Conn { + return startClientOptions(t, logger, serverURL, myID, peerID, &tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IPFromUUID(myID), 128)}, + DERPMap: basicDERPMap(t, serverURL), + BlockEndpoints: true, + Logger: logger, + DERPForceWebSockets: true, + // These tests don't have internet connection, so we need to force + // magicsock to do anything. + ForceNetworkUp: true, + }) +} + +type ClientStarter struct { + Options *tailnet.Options +} + +func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, myID, peerID uuid.UUID, options *tailnet.Options) *tailnet.Conn { u, err := serverURL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", myID.String())) require.NoError(t, err) //nolint:bodyclose @@ -156,15 +209,7 @@ func StartClientBasic(t *testing.T, logger slog.Logger, serverURL *url.URL, myID coord, err := client.Coordinate(context.Background()) require.NoError(t, err) - conn, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IPFromUUID(myID), 128)}, - DERPMap: basicDERPMap(t, serverURL), - BlockEndpoints: true, - Logger: logger, - // These tests don't have internet connection, so we need to force - // magicsock to do anything. - ForceNetworkUp: true, - }) + conn, err := tailnet.NewConn(options) require.NoError(t, err) t.Cleanup(func() { _ = conn.Close() diff --git a/tailnet/test/integration/integration_test.go b/tailnet/test/integration/integration_test.go index 76b57fecae..dcd64b9343 100644 --- a/tailnet/test/integration/integration_test.go +++ b/tailnet/test/integration/integration_test.go @@ -20,6 +20,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -68,19 +69,48 @@ func TestMain(m *testing.M) { var topologies = []integration.TestTopology{ { + // Test that DERP over loopback works. Name: "BasicLoopbackDERP", SetupNetworking: integration.SetupNetworkingLoopback, - StartServer: integration.StartServerBasic, - StartClient: integration.StartClientBasic, + ServerOptions: integration.ServerOptions{}, + StartClient: integration.StartClientDERP, RunTests: integration.TestSuite, }, { + // Test that DERP over "easy" NAT works. The server, client 1 and client + // 2 are on different networks with a shared router, and the router + // masquerades the traffic. Name: "EasyNATDERP", SetupNetworking: integration.SetupNetworkingEasyNAT, - StartServer: integration.StartServerBasic, - StartClient: integration.StartClientBasic, + ServerOptions: integration.ServerOptions{}, + StartClient: integration.StartClientDERP, RunTests: integration.TestSuite, }, + { + // Test that DERP over WebSocket (as well as DERPForceWebSockets works). + // This does not test the actual DERP failure detection code and + // automatic fallback. + Name: "DERPForceWebSockets", + SetupNetworking: integration.SetupNetworkingEasyNAT, + ServerOptions: integration.ServerOptions{ + FailUpgradeDERP: false, + DERPWebsocketOnly: true, + }, + StartClient: integration.StartClientDERPWebSockets, + RunTests: integration.TestSuite, + }, + { + // Test that falling back to DERP over WebSocket works. + Name: "DERPFallbackWebSockets", + SetupNetworking: integration.SetupNetworkingEasyNAT, + ServerOptions: integration.ServerOptions{ + FailUpgradeDERP: true, + DERPWebsocketOnly: false, + }, + // Use a basic client that will try `Upgrade: derp` first. + StartClient: integration.StartClientDERP, + RunTests: integration.TestSuite, + }, } //nolint:paralleltest,tparallel @@ -101,19 +131,17 @@ func TestIntegration(t *testing.T) { networking := topo.SetupNetworking(t, log) // Fork the three child processes. - serverErrCh, closeServer := startServerSubprocess(t, topo.Name, networking) + closeServer := startServerSubprocess(t, topo.Name, networking) // client1 runs the tests. client1ErrCh, _ := startClientSubprocess(t, topo.Name, networking, 1) - client2ErrCh, closeClient2 := startClientSubprocess(t, topo.Name, networking, 2) + _, closeClient2 := startClientSubprocess(t, topo.Name, networking, 2) // Wait for client1 to exit. require.NoError(t, <-client1ErrCh, "client 1 exited") // Close client2 and the server. - closeClient2() - require.NoError(t, <-client2ErrCh, "client 2 exited") - closeServer() - require.NoError(t, <-serverErrCh, "server exited") + require.NoError(t, closeClient2(), "client 2 exited") + require.NoError(t, closeServer(), "server exited") }) } } @@ -138,15 +166,32 @@ func handleTestSubprocess(t *testing.T) { //nolint:parralleltest t.Run(testName, func(t *testing.T) { - log := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) switch *role { case "server": - log = log.Named("server") - topo.StartServer(t, log, *serverListenAddr) + logger = logger.Named("server") + + srv := http.Server{ + Addr: *serverListenAddr, + Handler: topo.ServerOptions.Router(t, logger), + ReadTimeout: 10 * time.Second, + } + serveDone := make(chan struct{}) + go func() { + defer close(serveDone) + err := srv.ListenAndServe() + if err != nil && !xerrors.Is(err, http.ErrServerClosed) { + t.Error("HTTP server error:", err) + } + }() + t.Cleanup(func() { + _ = srv.Close() + <-serveDone + }) // no exit case "client": - log = log.Named(*clientName) + logger = logger.Named(*clientName) serverURL, err := url.Parse(*clientServerURL) require.NoErrorf(t, err, "parse server url %q", *clientServerURL) myID, err := uuid.Parse(*clientMyID) @@ -156,7 +201,7 @@ func handleTestSubprocess(t *testing.T) { waitForServerAvailable(t, serverURL) - conn := topo.StartClient(t, log, serverURL, myID, peerID) + conn := topo.StartClient(t, logger, serverURL, myID, peerID) if *clientRunTests { // Wait for connectivity. @@ -165,7 +210,7 @@ func handleTestSubprocess(t *testing.T) { t.Fatalf("peer %v did not become reachable", peerIP) } - topo.RunTests(t, log, serverURL, myID, peerID, conn) + topo.RunTests(t, logger, serverURL, myID, peerID, conn) // then exit return } @@ -206,16 +251,17 @@ func waitForServerAvailable(t *testing.T, serverURL *url.URL) { t.Fatalf("server did not become available after %v", timeout) } -func startServerSubprocess(t *testing.T, topologyName string, networking integration.TestNetworking) (<-chan error, func()) { - return startSubprocess(t, "server", networking.ProcessServer.NetNS, []string{ +func startServerSubprocess(t *testing.T, topologyName string, networking integration.TestNetworking) func() error { + _, closeFn := startSubprocess(t, "server", networking.ProcessServer.NetNS, []string{ "--subprocess", "--test-name=" + topologyName, "--role=server", "--server-listen-addr=" + networking.ServerListenAddr, }) + return closeFn } -func startClientSubprocess(t *testing.T, topologyName string, networking integration.TestNetworking, clientNumber int) (<-chan error, func()) { +func startClientSubprocess(t *testing.T, topologyName string, networking integration.TestNetworking, clientNumber int) (<-chan error, func() error) { require.True(t, clientNumber == 1 || clientNumber == 2) var ( @@ -247,7 +293,13 @@ func startClientSubprocess(t *testing.T, topologyName string, networking integra return startSubprocess(t, clientName, netNS, flags) } -func startSubprocess(t *testing.T, processName string, netNS *os.File, flags []string) (<-chan error, func()) { +// startSubprocess starts a subprocess with the given flags and returns a +// channel that will receive the error when the subprocess exits. The returned +// function can be used to close the subprocess. +// +// Do not call close then wait on the channel. Use the returned value from the +// function instead in this case. +func startSubprocess(t *testing.T, processName string, netNS *os.File, flags []string) (<-chan error, func() error) { name := os.Args[0] // Always use verbose mode since it gets piped to the parent test anyways. args := append(os.Args[1:], append([]string{"-test.v=true"}, flags...)...) @@ -289,15 +341,15 @@ func startSubprocess(t *testing.T, processName string, netNS *os.File, flags []s close(waitErr) }() - closeFn := func() { + closeFn := func() error { _ = cmd.Process.Signal(syscall.SIGTERM) select { case <-time.After(5 * time.Second): _ = cmd.Process.Kill() - case <-waitErr: - return + case err := <-waitErr: + return err } - <-waitErr + return <-waitErr } t.Cleanup(func() { @@ -310,7 +362,7 @@ func startSubprocess(t *testing.T, processName string, netNS *os.File, flags []s default: } - closeFn() + _ = closeFn() }) return waitErr, closeFn @@ -338,6 +390,11 @@ func (w *testWriter) Write(p []byte) (n int, err error) { // then it's a test result line. We want to capture it and log it later. trimmed := strings.TrimSpace(s) if strings.HasPrefix(trimmed, "--- PASS") || strings.HasPrefix(trimmed, "--- FAIL") || trimmed == "PASS" || trimmed == "FAIL" { + // Also fail the test if we see a FAIL line. + if strings.Contains(trimmed, "FAIL") { + w.t.Errorf("subprocess logged test failure: %s: \t%s", w.name, s) + } + w.capturedLines = append(w.capturedLines, s) continue } diff --git a/tailnet/test/integration/network.go b/tailnet/test/integration/network.go index 604d7827cd..f36ac63745 100644 --- a/tailnet/test/integration/network.go +++ b/tailnet/test/integration/network.go @@ -29,9 +29,9 @@ type TestTopology struct { // a network namespace shared for all tests. SetupNetworking func(t *testing.T, logger slog.Logger) TestNetworking - // StartServer gets called in the server subprocess. It's expected to start - // the coordinator server in the background and return. - StartServer func(t *testing.T, logger slog.Logger, listenAddr string) + // ServerOptions is the configuration for the server. It's passed to the + // server process. + ServerOptions ServerOptions // StartClient gets called in each client subprocess. It's expected to // create the tailnet.Conn and ensure connectivity to it's peer. StartClient func(t *testing.T, logger slog.Logger, serverURL *url.URL, myID uuid.UUID, peerID uuid.UUID) *tailnet.Conn