diff --git a/agent/agent.go b/agent/agent.go index ab882a80ef..238aee9ad1 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -8,6 +8,7 @@ import ( "fmt" "hash/fnv" "io" + "maps" "net" "net/http" "net/netip" @@ -70,16 +71,21 @@ const ( ) type Options struct { - Filesystem afero.Fs - LogDir string - TempDir string - ScriptDataDir string - Client Client - ReconnectingPTYTimeout time.Duration - EnvironmentVariables map[string]string - Logger slog.Logger - IgnorePorts map[int]string - PortCacheDuration time.Duration + Filesystem afero.Fs + LogDir string + TempDir string + ScriptDataDir string + Client Client + ReconnectingPTYTimeout time.Duration + EnvironmentVariables map[string]string + Logger slog.Logger + // IgnorePorts tells the api handler which ports to ignore when + // listing all listening ports. This is helpful to hide ports that + // are used by the agent, that the user does not care about. + IgnorePorts map[int]string + // ListeningPortsGetter is used to get the list of listening ports. Only + // tests should set this. If unset, a default that queries the OS will be used. + ListeningPortsGetter ListeningPortsGetter SSHMaxTimeout time.Duration TailnetListenPort uint16 Subsystems []codersdk.AgentSubsystem @@ -137,9 +143,7 @@ func New(options Options) Agent { if options.ServiceBannerRefreshInterval == 0 { options.ServiceBannerRefreshInterval = 2 * time.Minute } - if options.PortCacheDuration == 0 { - options.PortCacheDuration = 1 * time.Second - } + if options.Clock == nil { options.Clock = quartz.NewReal() } @@ -153,30 +157,38 @@ func New(options Options) Agent { options.Execer = agentexec.DefaultExecer } + if options.ListeningPortsGetter == nil { + options.ListeningPortsGetter = &osListeningPortsGetter{ + cacheDuration: 1 * time.Second, + } + } + hardCtx, hardCancel := context.WithCancel(context.Background()) gracefulCtx, gracefulCancel := context.WithCancel(hardCtx) a := &agent{ - clock: options.Clock, - tailnetListenPort: options.TailnetListenPort, - reconnectingPTYTimeout: options.ReconnectingPTYTimeout, - logger: options.Logger, - gracefulCtx: gracefulCtx, - gracefulCancel: gracefulCancel, - hardCtx: hardCtx, - hardCancel: hardCancel, - coordDisconnected: make(chan struct{}), - environmentVariables: options.EnvironmentVariables, - client: options.Client, - filesystem: options.Filesystem, - logDir: options.LogDir, - tempDir: options.TempDir, - scriptDataDir: options.ScriptDataDir, - lifecycleUpdate: make(chan struct{}, 1), - lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), - lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, - reportConnectionsUpdate: make(chan struct{}, 1), - ignorePorts: options.IgnorePorts, - portCacheDuration: options.PortCacheDuration, + clock: options.Clock, + tailnetListenPort: options.TailnetListenPort, + reconnectingPTYTimeout: options.ReconnectingPTYTimeout, + logger: options.Logger, + gracefulCtx: gracefulCtx, + gracefulCancel: gracefulCancel, + hardCtx: hardCtx, + hardCancel: hardCancel, + coordDisconnected: make(chan struct{}), + environmentVariables: options.EnvironmentVariables, + client: options.Client, + filesystem: options.Filesystem, + logDir: options.LogDir, + tempDir: options.TempDir, + scriptDataDir: options.ScriptDataDir, + lifecycleUpdate: make(chan struct{}, 1), + lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), + lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, + reportConnectionsUpdate: make(chan struct{}, 1), + listeningPortsHandler: listeningPortsHandler{ + getter: options.ListeningPortsGetter, + ignorePorts: maps.Clone(options.IgnorePorts), + }, reportMetadataInterval: options.ReportMetadataInterval, announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval, sshMaxTimeout: options.SSHMaxTimeout, @@ -202,20 +214,16 @@ func New(options Options) Agent { } type agent struct { - clock quartz.Clock - logger slog.Logger - client Client - tailnetListenPort uint16 - filesystem afero.Fs - logDir string - tempDir string - scriptDataDir string - // ignorePorts tells the api handler which ports to ignore when - // listing all listening ports. This is helpful to hide ports that - // are used by the agent, that the user does not care about. - ignorePorts map[int]string - portCacheDuration time.Duration - subsystems []codersdk.AgentSubsystem + clock quartz.Clock + logger slog.Logger + client Client + tailnetListenPort uint16 + filesystem afero.Fs + logDir string + tempDir string + scriptDataDir string + listeningPortsHandler listeningPortsHandler + subsystems []codersdk.AgentSubsystem reconnectingPTYTimeout time.Duration reconnectingPTYServer *reconnectingpty.Server diff --git a/agent/api.go b/agent/api.go index f417a046c2..4e1da8b028 100644 --- a/agent/api.go +++ b/agent/api.go @@ -2,14 +2,13 @@ package agent import ( "net/http" - "sync" - "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" ) func (a *agent) apiHandler() http.Handler { @@ -20,23 +19,6 @@ func (a *agent) apiHandler() http.Handler { }) }) - // Make a copy to ensure the map is not modified after the handler is - // created. - cpy := make(map[int]string) - for k, b := range a.ignorePorts { - cpy[k] = b - } - - cacheDuration := 1 * time.Second - if a.portCacheDuration > 0 { - cacheDuration = a.portCacheDuration - } - - lp := &listeningPortsHandler{ - ignorePorts: cpy, - cacheDuration: cacheDuration, - } - if a.devcontainers { r.Mount("/api/v0/containers", a.containerAPI.Routes()) } else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil { @@ -57,7 +39,7 @@ func (a *agent) apiHandler() http.Handler { promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger) - r.Get("/api/v0/listening-ports", lp.handler) + r.Get("/api/v0/listening-ports", a.listeningPortsHandler.handler) r.Get("/api/v0/netcheck", a.HandleNetcheck) r.Post("/api/v0/list-directory", a.HandleLS) r.Get("/api/v0/read-file", a.HandleReadFile) @@ -72,22 +54,21 @@ func (a *agent) apiHandler() http.Handler { return r } -type listeningPortsHandler struct { - ignorePorts map[int]string - cacheDuration time.Duration +type ListeningPortsGetter interface { + GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) +} - //nolint: unused // used on some but not all platforms - mut sync.Mutex - //nolint: unused // used on some but not all platforms - ports []codersdk.WorkspaceAgentListeningPort - //nolint: unused // used on some but not all platforms - mtime time.Time +type listeningPortsHandler struct { + // In production code, this is set to an osListeningPortsGetter, but it can be overridden for + // testing. + getter ListeningPortsGetter + ignorePorts map[int]string } // handler returns a list of listening ports. This is tested by coderd's // TestWorkspaceAgentListeningPorts test. func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) { - ports, err := lp.getListeningPorts() + ports, err := lp.getter.GetListeningPorts() if err != nil { httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ Message: "Could not scan for listening ports.", @@ -96,7 +77,20 @@ func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request return } + filteredPorts := make([]codersdk.WorkspaceAgentListeningPort, 0, len(ports)) + for _, port := range ports { + if port.Port < workspacesdk.AgentMinimumListeningPort { + continue + } + + // Ignore ports that we've been told to ignore. + if _, ok := lp.ignorePorts[int(port.Port)]; ok { + continue + } + filteredPorts = append(filteredPorts, port) + } + httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.WorkspaceAgentListeningPortsResponse{ - Ports: ports, + Ports: filteredPorts, }) } diff --git a/agent/ports_supported.go b/agent/ports_supported.go index efa554de98..30df6caf7a 100644 --- a/agent/ports_supported.go +++ b/agent/ports_supported.go @@ -3,16 +3,23 @@ package agent import ( + "sync" "time" "github.com/cakturk/go-netstat/netstat" "golang.org/x/xerrors" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/workspacesdk" ) -func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { +type osListeningPortsGetter struct { + cacheDuration time.Duration + mut sync.Mutex + ports []codersdk.WorkspaceAgentListeningPort + mtime time.Time +} + +func (lp *osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { lp.mut.Lock() defer lp.mut.Unlock() @@ -33,12 +40,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL seen := make(map[uint16]struct{}, len(tabs)) ports := []codersdk.WorkspaceAgentListeningPort{} for _, tab := range tabs { - if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort { - continue - } - - // Ignore ports that we've been told to ignore. - if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok { + if tab.LocalAddr == nil { continue } diff --git a/agent/ports_supported_internal_test.go b/agent/ports_supported_internal_test.go new file mode 100644 index 0000000000..e16bd8a0c8 --- /dev/null +++ b/agent/ports_supported_internal_test.go @@ -0,0 +1,45 @@ +//go:build linux || (windows && amd64) + +package agent + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOSListeningPortsGetter(t *testing.T) { + t.Parallel() + + uut := &osListeningPortsGetter{ + cacheDuration: 1 * time.Hour, + } + + l, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer l.Close() + + ports, err := uut.GetListeningPorts() + require.NoError(t, err) + found := false + for _, port := range ports { + // #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535) + if port.Port == uint16(l.Addr().(*net.TCPAddr).Port) { + found = true + break + } + } + require.True(t, found) + + // check that we cache the ports + err = l.Close() + require.NoError(t, err) + portsNew, err := uut.GetListeningPorts() + require.NoError(t, err) + require.Equal(t, ports, portsNew) + + // note that it's unsafe to try to assert that a port does not exist in the response + // because the OS may reallocate the port very quickly. +} diff --git a/agent/ports_unsupported.go b/agent/ports_unsupported.go index 89ca4f1755..661956a3fc 100644 --- a/agent/ports_unsupported.go +++ b/agent/ports_unsupported.go @@ -2,9 +2,17 @@ package agent -import "github.com/coder/coder/v2/codersdk" +import ( + "time" -func (*listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { + "github.com/coder/coder/v2/codersdk" +) + +type osListeningPortsGetter struct { + cacheDuration time.Duration +} + +func (*osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { // Can't scan for ports on non-linux or non-windows_amd64 systems at the // moment. The UI will not show any "no ports found" message to the user, so // the user won't suspect a thing. diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index e950f97075..6c12f91d37 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -5,12 +5,10 @@ import ( "encoding/json" "fmt" "maps" - "net" "net/http" "os" "path/filepath" - "runtime" - "strconv" + "slices" "strings" "sync" "sync/atomic" @@ -934,17 +932,45 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) { require.False(t, p2p) } +type fakeListeningPortsGetter struct { + sync.Mutex + ports []codersdk.WorkspaceAgentListeningPort +} + +func (g *fakeListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) { + g.Lock() + defer g.Unlock() + return slices.Clone(g.ports), nil +} + +func (g *fakeListeningPortsGetter) setPorts(ports ...codersdk.WorkspaceAgentListeningPort) { + g.Lock() + defer g.Unlock() + g.ports = slices.Clone(ports) +} + func TestWorkspaceAgentListeningPorts(t *testing.T) { t.Parallel() - setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uint16, uuid.UUID) { + testPort := codersdk.WorkspaceAgentListeningPort{ + Network: "tcp", + ProcessName: "test-app", + Port: 44762, + } + filteredPort := codersdk.WorkspaceAgentListeningPort{ + Network: "tcp", + ProcessName: "postgres", + Port: 5432, + } + + setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uuid.UUID, *fakeListeningPortsGetter) { t.Helper() client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ DeploymentValues: dv, }) - coderdPort, err := strconv.Atoi(client.URL.Port()) - require.NoError(t, err) + + fLPG := &fakeListeningPortsGetter{} user := coderdtest.CreateFirstUser(t, client) r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ @@ -955,228 +981,73 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) { return agents }).Do() _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { - o.PortCacheDuration = time.Millisecond + o.ListeningPortsGetter = fLPG }) - resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) + resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() // #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535) - return client, uint16(coderdPort), resources[0].Agents[0].ID + return client, resources[0].Agents[0].ID, fLPG } - willFilterPort := func(port int) bool { - if port < workspacesdk.AgentMinimumListeningPort || port > 65535 { - return true - } - if _, ok := workspacesdk.AgentIgnoredListeningPorts[uint16(port)]; ok { - return true - } - - return false - } - - generateUnfilteredPort := func(t *testing.T) (net.Listener, uint16) { - var ( - l net.Listener - port uint16 - ) - require.Eventually(t, func() bool { - var err error - l, err = net.Listen("tcp", "localhost:0") - if err != nil { - return false - } - tcpAddr, _ := l.Addr().(*net.TCPAddr) - if willFilterPort(tcpAddr.Port) { - _ = l.Close() - return false - } - t.Cleanup(func() { - _ = l.Close() - }) - - // #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535) - port = uint16(tcpAddr.Port) - return true - }, testutil.WaitShort, testutil.IntervalFast) - - return l, port - } - - generateFilteredPort := func(t *testing.T) (net.Listener, uint16) { - var ( - l net.Listener - port uint16 - ) - require.Eventually(t, func() bool { - for ignoredPort := range workspacesdk.AgentIgnoredListeningPorts { - if ignoredPort < 1024 || ignoredPort == 5432 { - continue - } - - var err error - l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", ignoredPort)) - if err != nil { - continue - } - t.Cleanup(func() { - _ = l.Close() - }) - - port = ignoredPort - return true - } - - return false - }, testutil.WaitShort, testutil.IntervalFast) - - return l, port - } - - t.Run("LinuxAndWindows", func(t *testing.T) { - t.Parallel() - if runtime.GOOS != "linux" && runtime.GOOS != "windows" { - t.Skip("only runs on linux and windows") - return - } - - for _, tc := range []struct { - name string - setDV func(t *testing.T, dv *codersdk.DeploymentValues) - }{ - { - name: "Mainline", - setDV: func(*testing.T, *codersdk.DeploymentValues) {}, - }, - { - name: "BlockDirect", - setDV: func(t *testing.T, dv *codersdk.DeploymentValues) { - err := dv.DERP.Config.BlockDirect.Set("true") - require.NoError(t, err) - require.True(t, dv.DERP.Config.BlockDirect.Value()) - }, - }, - } { - t.Run("OK_"+tc.name, func(t *testing.T) { - t.Parallel() - - dv := coderdtest.DeploymentValues(t) - tc.setDV(t, dv) - client, coderdPort, agentID := setup(t, nil, dv) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - // Generate a random unfiltered port. - l, lPort := generateUnfilteredPort(t) - - // List ports and ensure that the port we expect to see is there. - res, err := client.WorkspaceAgentListeningPorts(ctx, agentID) + for _, tc := range []struct { + name string + setDV func(t *testing.T, dv *codersdk.DeploymentValues) + }{ + { + name: "Mainline", + setDV: func(*testing.T, *codersdk.DeploymentValues) {}, + }, + { + name: "BlockDirect", + setDV: func(t *testing.T, dv *codersdk.DeploymentValues) { + err := dv.DERP.Config.BlockDirect.Set("true") require.NoError(t, err) - - expected := map[uint16]bool{ - // expect the listener we made - lPort: false, - // expect the coderdtest server - coderdPort: false, - } - for _, port := range res.Ports { - if port.Network == "tcp" { - if val, ok := expected[port.Port]; ok { - if val { - t.Fatalf("expected to find TCP port %d only once in response", port.Port) - } - } - expected[port.Port] = true - } - } - for port, found := range expected { - if !found { - t.Fatalf("expected to find TCP port %d in response", port) - } - } - - // Close the listener and check that the port is no longer in the response. - require.NoError(t, l.Close()) - t.Log("checking for ports after listener close:") - require.Eventually(t, func() bool { - res, err = client.WorkspaceAgentListeningPorts(ctx, agentID) - if !assert.NoError(t, err) { - return false - } - - for _, port := range res.Ports { - if port.Network == "tcp" && port.Port == lPort { - t.Logf("expected to not find TCP port %d in response", lPort) - return false - } - } - return true - }, testutil.WaitLong, testutil.IntervalMedium) - }) - } - - t.Run("Filter", func(t *testing.T) { + require.True(t, dv.DERP.Config.BlockDirect.Value()) + }, + }, + } { + t.Run("OK_"+tc.name, func(t *testing.T) { t.Parallel() - // Generate an unfiltered port that we will create an app for and - // should not exist in the response. - _, appLPort := generateUnfilteredPort(t) - app := &proto.App{ - Slug: "test-app", - Url: fmt.Sprintf("http://localhost:%d", appLPort), - } - - // Generate a filtered port that should not exist in the response. - _, filteredLPort := generateFilteredPort(t) - - client, coderdPort, agentID := setup(t, []*proto.App{app}, nil) + dv := coderdtest.DeploymentValues(t) + tc.setDV(t, dv) + client, agentID, fLPG := setup(t, nil, dv) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() + fLPG.setPorts(testPort) + + // List ports and ensure that the port we expect to see is there. res, err := client.WorkspaceAgentListeningPorts(ctx, agentID) require.NoError(t, err) + require.Equal(t, []codersdk.WorkspaceAgentListeningPort{testPort}, res.Ports) - sawCoderdPort := false - for _, port := range res.Ports { - if port.Network == "tcp" { - if port.Port == appLPort { - t.Fatalf("expected to not find TCP port (app port) %d in response", appLPort) - } - if port.Port == filteredLPort { - t.Fatalf("expected to not find TCP port (filtered port) %d in response", filteredLPort) - } - if port.Port == coderdPort { - sawCoderdPort = true - } - } - } - if !sawCoderdPort { - t.Fatalf("expected to find TCP port (coderd port) %d in response", coderdPort) - } + // Remove the port and check that the port is no longer in the response. + fLPG.setPorts() + res, err = client.WorkspaceAgentListeningPorts(ctx, agentID) + require.NoError(t, err) + require.Empty(t, res.Ports) }) - }) + } - t.Run("Darwin", func(t *testing.T) { + t.Run("Filter", func(t *testing.T) { t.Parallel() - if runtime.GOOS != "darwin" { - t.Skip("only runs on darwin") - return + + app := &proto.App{ + Slug: testPort.ProcessName, + Url: fmt.Sprintf("http://localhost:%d", testPort.Port), } - client, _, agentID := setup(t, nil, nil) + client, agentID, fLPG := setup(t, []*proto.App{app}, nil) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - // Create a TCP listener on a random port. - l, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - defer l.Close() + fLPG.setPorts(testPort, filteredPort) - // List ports and ensure that the list is empty because we're on darwin. res, err := client.WorkspaceAgentListeningPorts(ctx, agentID) require.NoError(t, err) - require.Len(t, res.Ports, 0) + require.Empty(t, res.Ports) }) }