From afd40436f0d5895673fed1c85ad43e3b3e0a04e8 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 25 Nov 2025 14:25:24 +0400 Subject: [PATCH] fix: mock Agent querying OS for listening ports in tests (#20842) fixes https://github.com/coder/internal/issues/1123 We want to tests that ports are not included after they are no longer used, but this isn't safe on the real OS networking stack because there is no way to guarantee a port _won't_ be used. Instead, we introduce an interface and fake implementation for testing. On order to leave the filtering logic in the test path, this PR also does some refactoring. Caching logic is left in the real OS querying implementation and a new test case is added for it in this PR. --- agent/agent.go | 104 +++++----- agent/api.go | 56 +++-- agent/ports_supported.go | 18 +- agent/ports_supported_internal_test.go | 45 ++++ agent/ports_unsupported.go | 12 +- coderd/workspaceagents_test.go | 275 +++++++------------------ 6 files changed, 219 insertions(+), 291 deletions(-) create mode 100644 agent/ports_supported_internal_test.go diff --git a/agent/agent.go b/agent/agent.go index ab882a80ef..238aee9ad1 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -8,6 +8,7 @@ import ( "fmt" "hash/fnv" "io" + "maps" "net" "net/http" "net/netip" @@ -70,16 +71,21 @@ const ( ) type Options struct { - Filesystem afero.Fs - LogDir string - TempDir string - ScriptDataDir string - Client Client - ReconnectingPTYTimeout time.Duration - EnvironmentVariables map[string]string - Logger slog.Logger - IgnorePorts map[int]string - PortCacheDuration time.Duration + Filesystem afero.Fs + LogDir string + TempDir string + ScriptDataDir string + Client Client + ReconnectingPTYTimeout time.Duration + EnvironmentVariables map[string]string + Logger slog.Logger + // IgnorePorts tells the api handler which ports to ignore when + // listing all listening ports. This is helpful to hide ports that + // are used by the agent, that the user does not care about. + IgnorePorts map[int]string + // ListeningPortsGetter is used to get the list of listening ports. Only + // tests should set this. If unset, a default that queries the OS will be used. + ListeningPortsGetter ListeningPortsGetter SSHMaxTimeout time.Duration TailnetListenPort uint16 Subsystems []codersdk.AgentSubsystem @@ -137,9 +143,7 @@ func New(options Options) Agent { if options.ServiceBannerRefreshInterval == 0 { options.ServiceBannerRefreshInterval = 2 * time.Minute } - if options.PortCacheDuration == 0 { - options.PortCacheDuration = 1 * time.Second - } + if options.Clock == nil { options.Clock = quartz.NewReal() } @@ -153,30 +157,38 @@ func New(options Options) Agent { options.Execer = agentexec.DefaultExecer } + if options.ListeningPortsGetter == nil { + options.ListeningPortsGetter = &osListeningPortsGetter{ + cacheDuration: 1 * time.Second, + } + } + hardCtx, hardCancel := context.WithCancel(context.Background()) gracefulCtx, gracefulCancel := context.WithCancel(hardCtx) a := &agent{ - clock: options.Clock, - tailnetListenPort: options.TailnetListenPort, - reconnectingPTYTimeout: options.ReconnectingPTYTimeout, - logger: options.Logger, - gracefulCtx: gracefulCtx, - gracefulCancel: gracefulCancel, - hardCtx: hardCtx, - hardCancel: hardCancel, - coordDisconnected: make(chan struct{}), - environmentVariables: options.EnvironmentVariables, - client: options.Client, - filesystem: options.Filesystem, - logDir: options.LogDir, - tempDir: options.TempDir, - scriptDataDir: options.ScriptDataDir, - lifecycleUpdate: make(chan struct{}, 1), - lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), - lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, - reportConnectionsUpdate: make(chan struct{}, 1), - ignorePorts: options.IgnorePorts, - portCacheDuration: options.PortCacheDuration, + clock: options.Clock, + tailnetListenPort: options.TailnetListenPort, + reconnectingPTYTimeout: options.ReconnectingPTYTimeout, + logger: options.Logger, + gracefulCtx: gracefulCtx, + gracefulCancel: gracefulCancel, + hardCtx: hardCtx, + hardCancel: hardCancel, + coordDisconnected: make(chan struct{}), + environmentVariables: options.EnvironmentVariables, + client: options.Client, + filesystem: options.Filesystem, + logDir: options.LogDir, + tempDir: options.TempDir, + scriptDataDir: options.ScriptDataDir, + lifecycleUpdate: make(chan struct{}, 1), + lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), + lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, + reportConnectionsUpdate: make(chan struct{}, 1), + listeningPortsHandler: listeningPortsHandler{ + getter: options.ListeningPortsGetter, + ignorePorts: maps.Clone(options.IgnorePorts), + }, reportMetadataInterval: options.ReportMetadataInterval, announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval, sshMaxTimeout: options.SSHMaxTimeout, @@ -202,20 +214,16 @@ func New(options Options) Agent { } type agent struct { - clock quartz.Clock - logger slog.Logger - client Client - tailnetListenPort uint16 - filesystem afero.Fs - logDir string - tempDir string - scriptDataDir string - // ignorePorts tells the api handler which ports to ignore when - // listing all listening ports. This is helpful to hide ports that - // are used by the agent, that the user does not care about. - ignorePorts map[int]string - portCacheDuration time.Duration - subsystems []codersdk.AgentSubsystem + clock quartz.Clock + logger slog.Logger + client Client + tailnetListenPort uint16 + filesystem afero.Fs + logDir string + tempDir string + scriptDataDir string + listeningPortsHandler listeningPortsHandler + subsystems []codersdk.AgentSubsystem reconnectingPTYTimeout time.Duration reconnectingPTYServer *reconnectingpty.Server diff --git a/agent/api.go b/agent/api.go index f417a046c2..4e1da8b028 100644 --- a/agent/api.go +++ b/agent/api.go @@ -2,14 +2,13 @@ package agent import ( "net/http" - "sync" - "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" ) func (a *agent) apiHandler() http.Handler { @@ -20,23 +19,6 @@ func (a *agent) apiHandler() http.Handler { }) }) - // Make a copy to ensure the map is not modified after the handler is - // created. - cpy := make(map[int]string) - for k, b := range a.ignorePorts { - cpy[k] = b - } - - cacheDuration := 1 * time.Second - if a.portCacheDuration > 0 { - cacheDuration = a.portCacheDuration - } - - lp := &listeningPortsHandler{ - ignorePorts: cpy, - cacheDuration: cacheDuration, - } - if a.devcontainers { r.Mount("/api/v0/containers", a.containerAPI.Routes()) } else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil { @@ -57,7 +39,7 @@ func (a *agent) apiHandler() http.Handler { promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger) - r.Get("/api/v0/listening-ports", lp.handler) + r.Get("/api/v0/listening-ports", a.listeningPortsHandler.handler) r.Get("/api/v0/netcheck", a.HandleNetcheck) r.Post("/api/v0/list-directory", a.HandleLS) r.Get("/api/v0/read-file", a.HandleReadFile) @@ -72,22 +54,21 @@ func (a *agent) apiHandler() http.Handler { return r } -type listeningPortsHandler struct { - ignorePorts map[int]string - cacheDuration time.Duration +type ListeningPortsGetter interface { + GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) +} - //nolint: unused // used on some but not all platforms - mut sync.Mutex - //nolint: unused // used on some but not all platforms - ports []codersdk.WorkspaceAgentListeningPort - //nolint: unused // used on some but not all platforms - mtime time.Time +type listeningPortsHandler struct { + // In production code, this is set to an osListeningPortsGetter, but it can be overridden for + // testing. + getter ListeningPortsGetter + ignorePorts map[int]string } // handler returns a list of listening ports. This is tested by coderd's // TestWorkspaceAgentListeningPorts test. func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) { - ports, err := lp.getListeningPorts() + ports, err := lp.getter.GetListeningPorts() if err != nil { httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ Message: "Could not scan for listening ports.", @@ -96,7 +77,20 @@ func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request return } + filteredPorts := make([]codersdk.WorkspaceAgentListeningPort, 0, len(ports)) + for _, port := range ports { + if port.Port < workspacesdk.AgentMinimumListeningPort { + continue + } + + // Ignore ports that we've been told to ignore. + if _, ok := lp.ignorePorts[int(port.Port)]; ok { + continue + } + filteredPorts = append(filteredPorts, port) + } + httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.WorkspaceAgentListeningPortsResponse{ - Ports: ports, + Ports: filteredPorts, }) } diff --git a/agent/ports_supported.go b/agent/ports_supported.go index efa554de98..30df6caf7a 100644 --- a/agent/ports_supported.go +++ b/agent/ports_supported.go @@ -3,16 +3,23 @@ package agent import ( + "sync" "time" "github.com/cakturk/go-netstat/netstat" "golang.org/x/xerrors" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/workspacesdk" ) -func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { +type osListeningPortsGetter struct { + cacheDuration time.Duration + mut sync.Mutex + ports []codersdk.WorkspaceAgentListeningPort + mtime time.Time +} + +func (lp *osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { lp.mut.Lock() defer lp.mut.Unlock() @@ -33,12 +40,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL seen := make(map[uint16]struct{}, len(tabs)) ports := []codersdk.WorkspaceAgentListeningPort{} for _, tab := range tabs { - if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort { - continue - } - - // Ignore ports that we've been told to ignore. - if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok { + if tab.LocalAddr == nil { continue } diff --git a/agent/ports_supported_internal_test.go b/agent/ports_supported_internal_test.go new file mode 100644 index 0000000000..e16bd8a0c8 --- /dev/null +++ b/agent/ports_supported_internal_test.go @@ -0,0 +1,45 @@ +//go:build linux || (windows && amd64) + +package agent + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOSListeningPortsGetter(t *testing.T) { + t.Parallel() + + uut := &osListeningPortsGetter{ + cacheDuration: 1 * time.Hour, + } + + l, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer l.Close() + + ports, err := uut.GetListeningPorts() + require.NoError(t, err) + found := false + for _, port := range ports { + // #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535) + if port.Port == uint16(l.Addr().(*net.TCPAddr).Port) { + found = true + break + } + } + require.True(t, found) + + // check that we cache the ports + err = l.Close() + require.NoError(t, err) + portsNew, err := uut.GetListeningPorts() + require.NoError(t, err) + require.Equal(t, ports, portsNew) + + // note that it's unsafe to try to assert that a port does not exist in the response + // because the OS may reallocate the port very quickly. +} diff --git a/agent/ports_unsupported.go b/agent/ports_unsupported.go index 89ca4f1755..661956a3fc 100644 --- a/agent/ports_unsupported.go +++ b/agent/ports_unsupported.go @@ -2,9 +2,17 @@ package agent -import "github.com/coder/coder/v2/codersdk" +import ( + "time" -func (*listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { + "github.com/coder/coder/v2/codersdk" +) + +type osListeningPortsGetter struct { + cacheDuration time.Duration +} + +func (*osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { // Can't scan for ports on non-linux or non-windows_amd64 systems at the // moment. The UI will not show any "no ports found" message to the user, so // the user won't suspect a thing. diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index e950f97075..6c12f91d37 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -5,12 +5,10 @@ import ( "encoding/json" "fmt" "maps" - "net" "net/http" "os" "path/filepath" - "runtime" - "strconv" + "slices" "strings" "sync" "sync/atomic" @@ -934,17 +932,45 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) { require.False(t, p2p) } +type fakeListeningPortsGetter struct { + sync.Mutex + ports []codersdk.WorkspaceAgentListeningPort +} + +func (g *fakeListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { + g.Lock() + defer g.Unlock() + return slices.Clone(g.ports), nil +} + +func (g *fakeListeningPortsGetter) setPorts(ports ...codersdk.WorkspaceAgentListeningPort) { + g.Lock() + defer g.Unlock() + g.ports = slices.Clone(ports) +} + func TestWorkspaceAgentListeningPorts(t *testing.T) { t.Parallel() - setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uint16, uuid.UUID) { + testPort := codersdk.WorkspaceAgentListeningPort{ + Network: "tcp", + ProcessName: "test-app", + Port: 44762, + } + filteredPort := codersdk.WorkspaceAgentListeningPort{ + Network: "tcp", + ProcessName: "postgres", + Port: 5432, + } + + setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uuid.UUID, *fakeListeningPortsGetter) { t.Helper() client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ DeploymentValues: dv, }) - coderdPort, err := strconv.Atoi(client.URL.Port()) - require.NoError(t, err) + + fLPG := &fakeListeningPortsGetter{} user := coderdtest.CreateFirstUser(t, client) r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -955,228 +981,73 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) { return agents }).Do() _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { - o.PortCacheDuration = time.Millisecond + o.ListeningPortsGetter = fLPG }) - resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) + resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() // #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535) - return client, uint16(coderdPort), resources[0].Agents[0].ID + return client, resources[0].Agents[0].ID, fLPG } - willFilterPort := func(port int) bool { - if port < workspacesdk.AgentMinimumListeningPort || port > 65535 { - return true - } - if _, ok := workspacesdk.AgentIgnoredListeningPorts[uint16(port)]; ok { - return true - } - - return false - } - - generateUnfilteredPort := func(t *testing.T) (net.Listener, uint16) { - var ( - l net.Listener - port uint16 - ) - require.Eventually(t, func() bool { - var err error - l, err = net.Listen("tcp", "localhost:0") - if err != nil { - return false - } - tcpAddr, _ := l.Addr().(*net.TCPAddr) - if willFilterPort(tcpAddr.Port) { - _ = l.Close() - return false - } - t.Cleanup(func() { - _ = l.Close() - }) - - // #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535) - port = uint16(tcpAddr.Port) - return true - }, testutil.WaitShort, testutil.IntervalFast) - - return l, port - } - - generateFilteredPort := func(t *testing.T) (net.Listener, uint16) { - var ( - l net.Listener - port uint16 - ) - require.Eventually(t, func() bool { - for ignoredPort := range workspacesdk.AgentIgnoredListeningPorts { - if ignoredPort < 1024 || ignoredPort == 5432 { - continue - } - - var err error - l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", ignoredPort)) - if err != nil { - continue - } - t.Cleanup(func() { - _ = l.Close() - }) - - port = ignoredPort - return true - } - - return false - }, testutil.WaitShort, testutil.IntervalFast) - - return l, port - } - - t.Run("LinuxAndWindows", func(t *testing.T) { - t.Parallel() - if runtime.GOOS != "linux" && runtime.GOOS != "windows" { - t.Skip("only runs on linux and windows") - return - } - - for _, tc := range []struct { - name string - setDV func(t *testing.T, dv *codersdk.DeploymentValues) - }{ - { - name: "Mainline", - setDV: func(*testing.T, *codersdk.DeploymentValues) {}, - }, - { - name: "BlockDirect", - setDV: func(t *testing.T, dv *codersdk.DeploymentValues) { - err := dv.DERP.Config.BlockDirect.Set("true") - require.NoError(t, err) - require.True(t, dv.DERP.Config.BlockDirect.Value()) - }, - }, - } { - t.Run("OK_"+tc.name, func(t *testing.T) { - t.Parallel() - - dv := coderdtest.DeploymentValues(t) - tc.setDV(t, dv) - client, coderdPort, agentID := setup(t, nil, dv) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - // Generate a random unfiltered port. - l, lPort := generateUnfilteredPort(t) - - // List ports and ensure that the port we expect to see is there. - res, err := client.WorkspaceAgentListeningPorts(ctx, agentID) + for _, tc := range []struct { + name string + setDV func(t *testing.T, dv *codersdk.DeploymentValues) + }{ + { + name: "Mainline", + setDV: func(*testing.T, *codersdk.DeploymentValues) {}, + }, + { + name: "BlockDirect", + setDV: func(t *testing.T, dv *codersdk.DeploymentValues) { + err := dv.DERP.Config.BlockDirect.Set("true") require.NoError(t, err) - - expected := map[uint16]bool{ - // expect the listener we made - lPort: false, - // expect the coderdtest server - coderdPort: false, - } - for _, port := range res.Ports { - if port.Network == "tcp" { - if val, ok := expected[port.Port]; ok { - if val { - t.Fatalf("expected to find TCP port %d only once in response", port.Port) - } - } - expected[port.Port] = true - } - } - for port, found := range expected { - if !found { - t.Fatalf("expected to find TCP port %d in response", port) - } - } - - // Close the listener and check that the port is no longer in the response. - require.NoError(t, l.Close()) - t.Log("checking for ports after listener close:") - require.Eventually(t, func() bool { - res, err = client.WorkspaceAgentListeningPorts(ctx, agentID) - if !assert.NoError(t, err) { - return false - } - - for _, port := range res.Ports { - if port.Network == "tcp" && port.Port == lPort { - t.Logf("expected to not find TCP port %d in response", lPort) - return false - } - } - return true - }, testutil.WaitLong, testutil.IntervalMedium) - }) - } - - t.Run("Filter", func(t *testing.T) { + require.True(t, dv.DERP.Config.BlockDirect.Value()) + }, + }, + } { + t.Run("OK_"+tc.name, func(t *testing.T) { t.Parallel() - // Generate an unfiltered port that we will create an app for and - // should not exist in the response. - _, appLPort := generateUnfilteredPort(t) - app := &proto.App{ - Slug: "test-app", - Url: fmt.Sprintf("http://localhost:%d", appLPort), - } - - // Generate a filtered port that should not exist in the response. - _, filteredLPort := generateFilteredPort(t) - - client, coderdPort, agentID := setup(t, []*proto.App{app}, nil) + dv := coderdtest.DeploymentValues(t) + tc.setDV(t, dv) + client, agentID, fLPG := setup(t, nil, dv) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + fLPG.setPorts(testPort) + + // List ports and ensure that the port we expect to see is there. res, err := client.WorkspaceAgentListeningPorts(ctx, agentID) require.NoError(t, err) + require.Equal(t, []codersdk.WorkspaceAgentListeningPort{testPort}, res.Ports) - sawCoderdPort := false - for _, port := range res.Ports { - if port.Network == "tcp" { - if port.Port == appLPort { - t.Fatalf("expected to not find TCP port (app port) %d in response", appLPort) - } - if port.Port == filteredLPort { - t.Fatalf("expected to not find TCP port (filtered port) %d in response", filteredLPort) - } - if port.Port == coderdPort { - sawCoderdPort = true - } - } - } - if !sawCoderdPort { - t.Fatalf("expected to find TCP port (coderd port) %d in response", coderdPort) - } + // Remove the port and check that the port is no longer in the response. + fLPG.setPorts() + res, err = client.WorkspaceAgentListeningPorts(ctx, agentID) + require.NoError(t, err) + require.Empty(t, res.Ports) }) - }) + } - t.Run("Darwin", func(t *testing.T) { + t.Run("Filter", func(t *testing.T) { t.Parallel() - if runtime.GOOS != "darwin" { - t.Skip("only runs on darwin") - return + + app := &proto.App{ + Slug: testPort.ProcessName, + Url: fmt.Sprintf("http://localhost:%d", testPort.Port), } - client, _, agentID := setup(t, nil, nil) + client, agentID, fLPG := setup(t, []*proto.App{app}, nil) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - // Create a TCP listener on a random port. - l, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - defer l.Close() + fLPG.setPorts(testPort, filteredPort) - // List ports and ensure that the list is empty because we're on darwin. res, err := client.WorkspaceAgentListeningPorts(ctx, agentID) require.NoError(t, err) - require.Len(t, res.Ports, 0) + require.Empty(t, res.Ports) }) }