From 181e103201c8adab144c0edb46d37a503277f95f Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Tue, 21 Apr 2026 11:37:10 +1000 Subject: [PATCH] fix: reuse shared tailnet for coderd-hosted MCP workspace tools (#24460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem Coderd can expose an MCP server at `/api/experimental/mcp/http` (we have this enabled on dogfood). Its workspace tools dialed agents through a per-call client-side tailnet stack. Every tool call re-created a WireGuard device, netstack, magicsock + UDP sockets, DERP connection, coordinator websocket, and their goroutines — in a process that already runs a long-lived shared tailnet. The duplicate stacks drove up resource usage under load. ## Fix Route this server's tool calls through the existing shared tailnet, so none of those transports are reconstructed per call. Closing an `AgentConn` now releases a tunnel reference instead of tearing down a transport. ## Potential follow-up `coder exp mcp server` still builds a fresh tailnet per call. It pays per-call latency and causes coordinator/DERP churn. A shared CLI tailnet is more involved — unlike coderd, the CLI has no existing shared tailnet to reuse, so it would need a new long-lived client-side tailnet with reconnect, sleep/wake, and idle-destination handling. There's less motivation to optimize this, given the client-side MCP does not compete for resources with coderd. Closes CODAGT-199 > Generated by mux, but reviewed by a human --- coderd/mcp/mcp.go | 8 +- coderd/mcp/mcp_e2e_test.go | 96 +++++++++----------- coderd/mcp_http.go | 6 +- codersdk/toolsdk/bash.go | 2 +- codersdk/toolsdk/toolsdk.go | 105 +++++++++++++-------- codersdk/toolsdk/toolsdk_test.go | 126 ++++++++++++++++++++++++++ codersdk/workspacesdk/agentconn.go | 35 +++++++ codersdk/workspacesdk/workspacesdk.go | 4 + 8 files changed, 283 insertions(+), 99 deletions(-) diff --git a/coderd/mcp/mcp.go b/coderd/mcp/mcp.go index 3ce17867c4..59cd6566f1 100644 --- a/coderd/mcp/mcp.go +++ b/coderd/mcp/mcp.go @@ -72,13 +72,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Register all available MCP tools with the server excluding: // - ReportTask - which requires dependencies not available in the remote MCP context // - ChatGPT search and fetch tools, which are redundant with the standard tools. -func (s *Server) RegisterTools(client *codersdk.Client) error { +func (s *Server) RegisterTools(client *codersdk.Client, opts ...func(*toolsdk.Deps)) error { if client == nil { return xerrors.New("client cannot be nil: MCP HTTP server requires authenticated client") } // Create tool dependencies - toolDeps, err := toolsdk.NewDeps(client) + toolDeps, err := toolsdk.NewDeps(client, opts...) if err != nil { return xerrors.Errorf("failed to initialize tool dependencies: %w", err) } @@ -100,13 +100,13 @@ func (s *Server) RegisterTools(client *codersdk.Client) error { // We do not expose any extra ones because ChatGPT has an undocumented "Safety Scan" feature. // In my experiments, if I included extra tools in the MCP server, ChatGPT would often - but not always - // refuse to add Coder as a connector. -func (s *Server) RegisterChatGPTTools(client *codersdk.Client) error { +func (s *Server) RegisterChatGPTTools(client *codersdk.Client, opts ...func(*toolsdk.Deps)) error { if client == nil { return xerrors.New("client cannot be nil: MCP HTTP server requires authenticated client") } // Create tool dependencies - toolDeps, err := toolsdk.NewDeps(client) + toolDeps, err := toolsdk.NewDeps(client, opts...) if err != nil { return xerrors.Errorf("failed to initialize tool dependencies: %w", err) } diff --git a/coderd/mcp/mcp_e2e_test.go b/coderd/mcp/mcp_e2e_test.go index b713fd8155..c7dc888000 100644 --- a/coderd/mcp/mcp_e2e_test.go +++ b/coderd/mcp/mcp_e2e_test.go @@ -9,6 +9,8 @@ import ( "io" "net/http" "net/url" + "os" + "path/filepath" "strings" "testing" @@ -16,11 +18,16 @@ import ( mcpclient "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/coder/coder/v2/agent" + "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbfake" mcpserver "github.com/coder/coder/v2/coderd/mcp" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/toolsdk" @@ -215,21 +222,27 @@ func TestMCPHTTP_E2E_UnauthenticatedAccess(t *testing.T) { func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) { t.Parallel() - // Setup Coder server with full workspace environment - coderClient, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ - IncludeProvisionerDaemon: true, - }) + coderClient, closer, api := coderdtest.NewWithAPI(t, nil) defer closer.Close() user := coderdtest.CreateFirstUser(t, coderClient) + r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{ + Name: "myworkspace", + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() - // Create template and workspace for testing - version := coderdtest.CreateTemplateVersion(t, coderClient, user.OrganizationID, nil) - coderdtest.AwaitTemplateVersionJobCompleted(t, coderClient, version.ID) - template := coderdtest.CreateTemplate(t, coderClient, user.OrganizationID, version.ID) - workspace := coderdtest.CreateWorkspace(t, coderClient, template.ID) + fs := afero.NewMemMapFs() + tmpdir := os.TempDir() + require.NoError(t, fs.MkdirAll(tmpdir, 0o755)) + filePath := filepath.Join(tmpdir, "mcp-http-test.txt") + require.NoError(t, afero.WriteFile(fs, filePath, []byte("hello from mcp"), 0o644)) + + _ = agenttest.New(t, coderClient.URL, r.AgentToken, func(opts *agent.Options) { + opts.Filesystem = fs + }) + coderdtest.NewWorkspaceAgentWaiter(t, coderClient, r.Workspace.ID).Wait() - // Create MCP client mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL, transport.WithHTTPHeaders(map[string]string{ @@ -245,11 +258,8 @@ func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - // Start and initialize client - err = mcpClient.Start(ctx) - require.NoError(t, err) - - initReq := mcp.InitializeRequest{ + require.NoError(t, mcpClient.Start(ctx)) + _, err = mcpClient.Initialize(ctx, mcp.InitializeRequest{ Params: mcp.InitializeParams{ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, ClientInfo: mcp.Implementation{ @@ -257,48 +267,30 @@ func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) { Version: "1.0.0", }, }, - } - - _, err = mcpClient.Initialize(ctx, initReq) + }) require.NoError(t, err) - // Test workspace-related tools - tools, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) - require.NoError(t, err) - - // Find workspace listing tool - var workspaceTool *mcp.Tool - for _, tool := range tools.Tools { - if tool.Name == toolsdk.ToolNameListWorkspaces { - workspaceTool = &tool - break - } - } - - if workspaceTool != nil { - // Execute workspace listing tool - toolReq := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: workspaceTool.Name, - Arguments: map[string]any{}, + toolResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: toolsdk.ToolNameWorkspaceLS, + Arguments: map[string]any{ + "workspace": r.Workspace.Name, + "path": tmpdir, }, - } + }, + }) + require.NoError(t, err) + require.NotEmpty(t, toolResult.Content) - toolResult, err := mcpClient.CallTool(ctx, toolReq) - require.NoError(t, err) - require.NotEmpty(t, toolResult.Content) + textContent, ok := toolResult.Content[0].(mcp.TextContent) + require.True(t, ok, "expected TextContent type, got %T", toolResult.Content[0]) - // Verify the result mentions our workspace - if textContent, ok := toolResult.Content[0].(mcp.TextContent); ok { - assert.Contains(t, textContent.Text, workspace.Name, "Workspace listing should include our test workspace") - } else { - t.Error("Expected TextContent type from workspace tool") - } - - t.Logf("Workspace tool test successful: Found workspace %s in results", workspace.Name) - } else { - t.Skip("Workspace listing tool not available, skipping workspace-specific test") - } + var response toolsdk.WorkspaceLSResponse + require.NoError(t, json.Unmarshal([]byte(textContent.Text), &response)) + assert.Contains(t, response.Contents, toolsdk.WorkspaceLSFile{ + Path: filePath, + IsDir: false, + }) } func TestMCPHTTP_E2E_ErrorHandling(t *testing.T) { diff --git a/coderd/mcp_http.go b/coderd/mcp_http.go index 859222b400..dd502a432a 100644 --- a/coderd/mcp_http.go +++ b/coderd/mcp_http.go @@ -9,6 +9,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/mcp" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/toolsdk" ) type MCPToolset string @@ -34,6 +35,7 @@ func (api *API) mcpHTTPHandler() http.Handler { // Extract the original session token from the request authenticatedClient := codersdk.New(api.AccessURL, codersdk.WithSessionToken(httpmw.APITokenFromRequest(r))) + toolOpt := toolsdk.WithAgentConnFunc(api.agentProvider.AgentConn) toolset := MCPToolset(r.URL.Query().Get("toolset")) // Default to standard toolset if no toolset is specified. if toolset == "" { @@ -42,11 +44,11 @@ func (api *API) mcpHTTPHandler() http.Handler { switch toolset { case MCPToolsetStandard: - if err := mcpServer.RegisterTools(authenticatedClient); err != nil { + if err := mcpServer.RegisterTools(authenticatedClient, toolOpt); err != nil { api.Logger.Warn(r.Context(), "failed to register MCP tools", slog.Error(err)) } case MCPToolsetChatGPT: - if err := mcpServer.RegisterChatGPTTools(authenticatedClient); err != nil { + if err := mcpServer.RegisterChatGPTTools(authenticatedClient, toolOpt); err != nil { api.Logger.Warn(r.Context(), "failed to register MCP tools", slog.Error(err)) } default: diff --git a/codersdk/toolsdk/bash.go b/codersdk/toolsdk/bash.go index 78a102fbc1..03f0905b61 100644 --- a/codersdk/toolsdk/bash.go +++ b/codersdk/toolsdk/bash.go @@ -101,7 +101,7 @@ Examples: ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min")) defer cancel() - conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + conn, err := openAgentConn(ctx, deps, args.Workspace) if err != nil { return WorkspaceBashResult{}, err } diff --git a/codersdk/toolsdk/toolsdk.go b/codersdk/toolsdk/toolsdk.go index a0db2d98bf..b6d1141a14 100644 --- a/codersdk/toolsdk/toolsdk.go +++ b/codersdk/toolsdk/toolsdk.go @@ -65,6 +65,16 @@ func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { for _, opt := range opts { opt(&d) } + if d.agentConnFn == nil && d.coderClient != nil { + workspaceClient := workspacesdk.New(d.coderClient) + d.agentConnFn = func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + conn, err := workspaceClient.DialAgent(ctx, agentID, nil) + if err != nil { + return nil, nil, err + } + return conn, nil, nil + } + } // Allow nil client for unauthenticated operation // This enables tools that don't require user authentication to function return d, nil @@ -74,6 +84,7 @@ func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) { type Deps struct { coderClient *codersdk.Client report func(ReportTaskArgs) error + agentConnFn workspacesdk.AgentConnFunc } func (d Deps) ServerURL() string { @@ -89,6 +100,55 @@ func WithTaskReporter(fn func(ReportTaskArgs) error) func(*Deps) { } } +// WithAgentConnFunc overrides how workspace tools open logical connections to +// workspace agents. +func WithAgentConnFunc(agentConnFn workspacesdk.AgentConnFunc) func(*Deps) { + return func(d *Deps) { + d.agentConnFn = agentConnFn + } +} + +// openAgentConn opens a ready workspace agent session for workspace inputs in +// [owner/]workspace[.agent] format. +func openAgentConn(ctx context.Context, deps Deps, workspace string) (workspacesdk.AgentConn, error) { + if deps.coderClient == nil { + return nil, xerrors.New("workspace tools require an authenticated client") + } + + workspaceName := NormalizeWorkspaceInput(workspace) + _, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName) + if err != nil { + return nil, xerrors.Errorf("failed to find workspace: %w", err) + } + + if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{ + FetchInterval: 0, + Fetch: deps.coderClient.WorkspaceAgent, + FetchLogs: deps.coderClient.WorkspaceAgentLogsAfter, + // Always wait for startup scripts. + Wait: true, + }); err != nil { + return nil, xerrors.Errorf("agent not ready: %w", err) + } + + conn, release, err := deps.agentConnFn(ctx, workspaceAgent.ID) + if err != nil { + return nil, xerrors.Errorf("failed to dial agent: %w", err) + } + + wrappedConn := workspacesdk.WrapAgentConn(conn, func() error { + if release != nil { + release() + } + return nil + }) + if wrappedConn == nil { + return nil, xerrors.New("agent connection function returned nil connection") + } + + return wrappedConn, nil +} + // HandlerFunc is a typed function that handles a tool call. type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error) @@ -1501,7 +1561,7 @@ var WorkspaceLS = Tool[WorkspaceLSArgs, WorkspaceLSResponse]{ MCPAnnotations: mcpReadOnlyAnnotations, UserClientOptional: true, Handler: func(ctx context.Context, deps Deps, args WorkspaceLSArgs) (WorkspaceLSResponse, error) { - conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + conn, err := openAgentConn(ctx, deps, args.Workspace) if err != nil { return WorkspaceLSResponse{}, err } @@ -1567,7 +1627,7 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{ MCPAnnotations: mcpReadOnlyAnnotations, UserClientOptional: true, Handler: func(ctx context.Context, deps Deps, args WorkspaceReadFileArgs) (WorkspaceReadFileResponse, error) { - conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + conn, err := openAgentConn(ctx, deps, args.Workspace) if err != nil { return WorkspaceReadFileResponse{}, err } @@ -1641,7 +1701,7 @@ content you are trying to write, then re-encode it properly. MCPAnnotations: mcpDestructiveAnnotations, UserClientOptional: true, Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) { - conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + conn, err := openAgentConn(ctx, deps, args.Workspace) if err != nil { return codersdk.Response{}, err } @@ -1716,7 +1776,7 @@ var WorkspaceEditFile = Tool[WorkspaceEditFileArgs, WorkspaceEditFilesResponse]{ MCPAnnotations: mcpDestructiveAnnotations, UserClientOptional: true, Handler: func(ctx context.Context, deps Deps, args WorkspaceEditFileArgs) (WorkspaceEditFilesResponse, error) { - conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + conn, err := openAgentConn(ctx, deps, args.Workspace) if err != nil { return WorkspaceEditFilesResponse{}, err } @@ -1800,7 +1860,7 @@ var WorkspaceEditFiles = Tool[WorkspaceEditFilesArgs, WorkspaceEditFilesResponse MCPAnnotations: mcpDestructiveAnnotations, UserClientOptional: true, Handler: func(ctx context.Context, deps Deps, args WorkspaceEditFilesArgs) (WorkspaceEditFilesResponse, error) { - conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace) + conn, err := openAgentConn(ctx, deps, args.Workspace) if err != nil { return WorkspaceEditFilesResponse{}, err } @@ -2245,41 +2305,6 @@ func NormalizeWorkspaceInput(input string) string { return normalized } -// newAgentConn returns a connection to the agent specified by the workspace, -// which must be in the format [owner/]workspace[.agent]. -func newAgentConn(ctx context.Context, client *codersdk.Client, workspace string) (workspacesdk.AgentConn, error) { - workspaceName := NormalizeWorkspaceInput(workspace) - _, workspaceAgent, err := findWorkspaceAndAgent(ctx, client, workspaceName) - if err != nil { - return nil, xerrors.Errorf("failed to find workspace: %w", err) - } - - // Wait for agent to be ready. - if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{ - FetchInterval: 0, - Fetch: client.WorkspaceAgent, - FetchLogs: client.WorkspaceAgentLogsAfter, - Wait: true, // Always wait for startup scripts - }); err != nil { - return nil, xerrors.Errorf("agent not ready: %w", err) - } - - wsClient := workspacesdk.New(client) - - conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{ - BlockEndpoints: false, - }) - if err != nil { - return nil, xerrors.Errorf("failed to dial agent: %w", err) - } - - if !conn.AwaitReachable(ctx) { - conn.Close() - return nil, xerrors.New("agent connection not reachable") - } - return conn, nil -} - const workspaceDescription = "The workspace ID or name in the format [owner/]workspace. If an owner is not specified, the authenticated user is used." const workspaceAgentDescription = "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used." diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index ec5567c4b2..2e5e005883 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "golang.org/x/xerrors" agentapi "github.com/coder/agentapi-sdk-go" "github.com/coder/aisdk-go" @@ -61,6 +62,22 @@ func setupWorkspaceForAgent(t *testing.T, opts *coderdtest.Options) (*codersdk.C return userClient, r.Workspace, r.AgentToken } +type recordingAgentConnFunc struct { + conn workspacesdk.AgentConn + err error + agentID uuid.UUID + calls int +} + +func (d *recordingAgentConnFunc) AgentConn(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + d.calls++ + d.agentID = agentID + if d.err != nil { + return nil, nil, d.err + } + return d.conn, nil, nil +} + // These tests are dependent on the state of the coder server. // Running them in parallel is prone to racy behavior. // nolint:tparallel,paralleltest @@ -597,6 +614,115 @@ func TestTools(t *testing.T) { }, res.Contents) }) + t.Run("WorkspaceToolsUseInjectedAgentConnFunc", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait() + + ws, err := client.Workspace(t.Context(), workspace.ID) + require.NoError(t, err) + require.NotEmpty(t, ws.LatestBuild.Resources) + require.NotEmpty(t, ws.LatestBuild.Resources[0].Agents) + agentID := ws.LatestBuild.Resources[0].Agents[0].ID + sentinelErr := xerrors.New("injected agent connection function used") + + tests := []struct { + name string + run func(t *testing.T, tb toolsdk.Deps) error + }{ + { + name: "WorkspaceLS", + run: func(t *testing.T, tb toolsdk.Deps) error { + _, err := testTool(t, toolsdk.WorkspaceLS, tb, toolsdk.WorkspaceLSArgs{ + Workspace: workspace.Name, + Path: "/tmp", + }) + return err + }, + }, + { + name: "WorkspaceReadFile", + run: func(t *testing.T, tb toolsdk.Deps) error { + _, err := testTool(t, toolsdk.WorkspaceReadFile, tb, toolsdk.WorkspaceReadFileArgs{ + Workspace: workspace.Name, + Path: "/tmp/file", + }) + return err + }, + }, + { + name: "WorkspaceWriteFile", + run: func(t *testing.T, tb toolsdk.Deps) error { + _, err := testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{ + Workspace: workspace.Name, + Path: "/tmp/file", + Content: []byte("hello from agent connection function"), + }) + return err + }, + }, + { + name: "WorkspaceEditFile", + run: func(t *testing.T, tb toolsdk.Deps) error { + _, err := testTool(t, toolsdk.WorkspaceEditFile, tb, toolsdk.WorkspaceEditFileArgs{ + Workspace: workspace.Name, + Path: "/tmp/file", + Edits: []workspacesdk.FileEdit{{ + Search: "hello", + Replace: "goodbye", + }}, + }) + return err + }, + }, + { + name: "WorkspaceEditFiles", + run: func(t *testing.T, tb toolsdk.Deps) error { + _, err := testTool(t, toolsdk.WorkspaceEditFiles, tb, toolsdk.WorkspaceEditFilesArgs{ + Workspace: workspace.Name, + Files: []workspacesdk.FileEdits{{ + Path: "/tmp/file", + Edits: []workspacesdk.FileEdit{{ + Search: "hello", + Replace: "goodbye", + }}, + }}, + }) + return err + }, + }, + { + name: "WorkspaceBash", + run: func(t *testing.T, tb toolsdk.Deps) error { + _, err := testTool(t, toolsdk.WorkspaceBash, tb, toolsdk.WorkspaceBashArgs{ + Workspace: workspace.Name, + Command: "echo hello", + }) + return err + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + agentConnFn := &recordingAgentConnFunc{err: sentinelErr} + tb, err := toolsdk.NewDeps(client, toolsdk.WithAgentConnFunc(agentConnFn.AgentConn)) + require.NoError(t, err) + + err = tt.run(t, tb) + require.ErrorIs(t, err, sentinelErr) + require.ErrorContains(t, err, "failed to dial agent") + require.Equal(t, 1, agentConnFn.calls) + require.Equal(t, agentID, agentConnFn.agentID) + }) + } + }) + t.Run("WorkspaceReadFile", func(t *testing.T) { t.Parallel() diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 58abc164d1..6882ff0d91 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "net" @@ -43,6 +44,40 @@ func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) AgentConn { } } +// WrapAgentConn returns an AgentConn that delegates every operation to conn and +// applies closeFunc exactly once when the logical session is closed. +// +// If conn is nil, any provided closeFunc is invoked immediately so logical +// session cleanup is not silently dropped. +func WrapAgentConn(conn AgentConn, closeFunc func() error) AgentConn { + if conn == nil { + if closeFunc != nil { + _ = closeFunc() + } + return nil + } + if closeFunc == nil { + closeFunc = func() error { return nil } + } + return &wrappedAgentConn{AgentConn: conn, closeFunc: closeFunc} +} + +type wrappedAgentConn struct { + AgentConn + closeFunc func() error + closeOnce sync.Once + closeErr error +} + +func (c *wrappedAgentConn) Close() error { + c.closeOnce.Do(func() { + // Close the underlying connection before releasing the logical session so + // the lease remains held until teardown is complete. + c.closeErr = errors.Join(c.AgentConn.Close(), c.closeFunc()) + }) + return c.closeErr +} + const ( // CoderChatIDHeader is the HTTP header containing the current // chat ID. Set by coderd on agentconn requests originating diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 9aa00646fc..67eab8b4bc 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -175,6 +175,10 @@ func (c *Client) AgentConnectionInfo(ctx context.Context, agentID uuid.UUID) (Ag return connInfo, json.NewDecoder(res.Body).Decode(&connInfo) } +// AgentConnFunc returns a new connection to the specified agent. If release is +// non-nil, callers must invoke it after they are done with the AgentConn. +type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (conn AgentConn, release func(), err error) + // @typescript-ignore DialAgentOptions type DialAgentOptions struct { Logger slog.Logger