mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: reuse shared tailnet for coderd-hosted MCP workspace tools (#24460)
## Problem Coderd can expose an MCP server at `/api/experimental/mcp/http` (we have this enabled on dogfood). Its workspace tools dialed agents through a per-call client-side tailnet stack. Every tool call re-created a WireGuard device, netstack, magicsock + UDP sockets, DERP connection, coordinator websocket, and their goroutines — in a process that already runs a long-lived shared tailnet. The duplicate stacks drove up resource usage under load. ## Fix Route this server's tool calls through the existing shared tailnet, so none of those transports are reconstructed per call. Closing an `AgentConn` now releases a tunnel reference instead of tearing down a transport. ## Potential follow-up `coder exp mcp server` still builds a fresh tailnet per call. It pays per-call latency and causes coordinator/DERP churn. A shared CLI tailnet is more involved — unlike coderd, the CLI has no existing shared tailnet to reuse, so it would need a new long-lived client-side tailnet with reconnect, sleep/wake, and idle-destination handling. There's less motivation to optimize this, given the client-side MCP does not compete for resources with coderd. Closes CODAGT-199 > Generated by mux, but reviewed by a human
This commit is contained in:
+4
-4
@@ -72,13 +72,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Register all available MCP tools with the server excluding:
|
||||
// - ReportTask - which requires dependencies not available in the remote MCP context
|
||||
// - ChatGPT search and fetch tools, which are redundant with the standard tools.
|
||||
func (s *Server) RegisterTools(client *codersdk.Client) error {
|
||||
func (s *Server) RegisterTools(client *codersdk.Client, opts ...func(*toolsdk.Deps)) error {
|
||||
if client == nil {
|
||||
return xerrors.New("client cannot be nil: MCP HTTP server requires authenticated client")
|
||||
}
|
||||
|
||||
// Create tool dependencies
|
||||
toolDeps, err := toolsdk.NewDeps(client)
|
||||
toolDeps, err := toolsdk.NewDeps(client, opts...)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
|
||||
}
|
||||
@@ -100,13 +100,13 @@ func (s *Server) RegisterTools(client *codersdk.Client) error {
|
||||
// We do not expose any extra ones because ChatGPT has an undocumented "Safety Scan" feature.
|
||||
// In my experiments, if I included extra tools in the MCP server, ChatGPT would often - but not always -
|
||||
// refuse to add Coder as a connector.
|
||||
func (s *Server) RegisterChatGPTTools(client *codersdk.Client) error {
|
||||
func (s *Server) RegisterChatGPTTools(client *codersdk.Client, opts ...func(*toolsdk.Deps)) error {
|
||||
if client == nil {
|
||||
return xerrors.New("client cannot be nil: MCP HTTP server requires authenticated client")
|
||||
}
|
||||
|
||||
// Create tool dependencies
|
||||
toolDeps, err := toolsdk.NewDeps(client)
|
||||
toolDeps, err := toolsdk.NewDeps(client, opts...)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
|
||||
}
|
||||
|
||||
+44
-52
@@ -9,6 +9,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -16,11 +18,16 @@ import (
|
||||
mcpclient "github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
mcpserver "github.com/coder/coder/v2/coderd/mcp"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
@@ -215,21 +222,27 @@ func TestMCPHTTP_E2E_UnauthenticatedAccess(t *testing.T) {
|
||||
func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup Coder server with full workspace environment
|
||||
coderClient, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
coderClient, closer, api := coderdtest.NewWithAPI(t, nil)
|
||||
defer closer.Close()
|
||||
|
||||
user := coderdtest.CreateFirstUser(t, coderClient)
|
||||
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
|
||||
Name: "myworkspace",
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
// Create template and workspace for testing
|
||||
version := coderdtest.CreateTemplateVersion(t, coderClient, user.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, coderClient, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, coderClient, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, coderClient, template.ID)
|
||||
fs := afero.NewMemMapFs()
|
||||
tmpdir := os.TempDir()
|
||||
require.NoError(t, fs.MkdirAll(tmpdir, 0o755))
|
||||
filePath := filepath.Join(tmpdir, "mcp-http-test.txt")
|
||||
require.NoError(t, afero.WriteFile(fs, filePath, []byte("hello from mcp"), 0o644))
|
||||
|
||||
_ = agenttest.New(t, coderClient.URL, r.AgentToken, func(opts *agent.Options) {
|
||||
opts.Filesystem = fs
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, coderClient, r.Workspace.ID).Wait()
|
||||
|
||||
// Create MCP client
|
||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
||||
transport.WithHTTPHeaders(map[string]string{
|
||||
@@ -245,11 +258,8 @@ func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Start and initialize client
|
||||
err = mcpClient.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
initReq := mcp.InitializeRequest{
|
||||
require.NoError(t, mcpClient.Start(ctx))
|
||||
_, err = mcpClient.Initialize(ctx, mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
ClientInfo: mcp.Implementation{
|
||||
@@ -257,48 +267,30 @@ func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) {
|
||||
Version: "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = mcpClient.Initialize(ctx, initReq)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test workspace-related tools
|
||||
tools, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find workspace listing tool
|
||||
var workspaceTool *mcp.Tool
|
||||
for _, tool := range tools.Tools {
|
||||
if tool.Name == toolsdk.ToolNameListWorkspaces {
|
||||
workspaceTool = &tool
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if workspaceTool != nil {
|
||||
// Execute workspace listing tool
|
||||
toolReq := mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Name: workspaceTool.Name,
|
||||
Arguments: map[string]any{},
|
||||
toolResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{
|
||||
Params: mcp.CallToolParams{
|
||||
Name: toolsdk.ToolNameWorkspaceLS,
|
||||
Arguments: map[string]any{
|
||||
"workspace": r.Workspace.Name,
|
||||
"path": tmpdir,
|
||||
},
|
||||
}
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, toolResult.Content)
|
||||
|
||||
toolResult, err := mcpClient.CallTool(ctx, toolReq)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, toolResult.Content)
|
||||
textContent, ok := toolResult.Content[0].(mcp.TextContent)
|
||||
require.True(t, ok, "expected TextContent type, got %T", toolResult.Content[0])
|
||||
|
||||
// Verify the result mentions our workspace
|
||||
if textContent, ok := toolResult.Content[0].(mcp.TextContent); ok {
|
||||
assert.Contains(t, textContent.Text, workspace.Name, "Workspace listing should include our test workspace")
|
||||
} else {
|
||||
t.Error("Expected TextContent type from workspace tool")
|
||||
}
|
||||
|
||||
t.Logf("Workspace tool test successful: Found workspace %s in results", workspace.Name)
|
||||
} else {
|
||||
t.Skip("Workspace listing tool not available, skipping workspace-specific test")
|
||||
}
|
||||
var response toolsdk.WorkspaceLSResponse
|
||||
require.NoError(t, json.Unmarshal([]byte(textContent.Text), &response))
|
||||
assert.Contains(t, response.Contents, toolsdk.WorkspaceLSFile{
|
||||
Path: filePath,
|
||||
IsDir: false,
|
||||
})
|
||||
}
|
||||
|
||||
func TestMCPHTTP_E2E_ErrorHandling(t *testing.T) {
|
||||
|
||||
+4
-2
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/mcp"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
)
|
||||
|
||||
type MCPToolset string
|
||||
@@ -34,6 +35,7 @@ func (api *API) mcpHTTPHandler() http.Handler {
|
||||
// Extract the original session token from the request
|
||||
authenticatedClient := codersdk.New(api.AccessURL,
|
||||
codersdk.WithSessionToken(httpmw.APITokenFromRequest(r)))
|
||||
toolOpt := toolsdk.WithAgentConnFunc(api.agentProvider.AgentConn)
|
||||
toolset := MCPToolset(r.URL.Query().Get("toolset"))
|
||||
// Default to standard toolset if no toolset is specified.
|
||||
if toolset == "" {
|
||||
@@ -42,11 +44,11 @@ func (api *API) mcpHTTPHandler() http.Handler {
|
||||
|
||||
switch toolset {
|
||||
case MCPToolsetStandard:
|
||||
if err := mcpServer.RegisterTools(authenticatedClient); err != nil {
|
||||
if err := mcpServer.RegisterTools(authenticatedClient, toolOpt); err != nil {
|
||||
api.Logger.Warn(r.Context(), "failed to register MCP tools", slog.Error(err))
|
||||
}
|
||||
case MCPToolsetChatGPT:
|
||||
if err := mcpServer.RegisterChatGPTTools(authenticatedClient); err != nil {
|
||||
if err := mcpServer.RegisterChatGPTTools(authenticatedClient, toolOpt); err != nil {
|
||||
api.Logger.Warn(r.Context(), "failed to register MCP tools", slog.Error(err))
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -101,7 +101,7 @@ Examples:
|
||||
ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min"))
|
||||
defer cancel()
|
||||
|
||||
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
|
||||
conn, err := openAgentConn(ctx, deps, args.Workspace)
|
||||
if err != nil {
|
||||
return WorkspaceBashResult{}, err
|
||||
}
|
||||
|
||||
+65
-40
@@ -65,6 +65,16 @@ func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
|
||||
for _, opt := range opts {
|
||||
opt(&d)
|
||||
}
|
||||
if d.agentConnFn == nil && d.coderClient != nil {
|
||||
workspaceClient := workspacesdk.New(d.coderClient)
|
||||
d.agentConnFn = func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
conn, err := workspaceClient.DialAgent(ctx, agentID, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, nil, nil
|
||||
}
|
||||
}
|
||||
// Allow nil client for unauthenticated operation
|
||||
// This enables tools that don't require user authentication to function
|
||||
return d, nil
|
||||
@@ -74,6 +84,7 @@ func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
|
||||
type Deps struct {
|
||||
coderClient *codersdk.Client
|
||||
report func(ReportTaskArgs) error
|
||||
agentConnFn workspacesdk.AgentConnFunc
|
||||
}
|
||||
|
||||
func (d Deps) ServerURL() string {
|
||||
@@ -89,6 +100,55 @@ func WithTaskReporter(fn func(ReportTaskArgs) error) func(*Deps) {
|
||||
}
|
||||
}
|
||||
|
||||
// WithAgentConnFunc overrides how workspace tools open logical connections to
|
||||
// workspace agents.
|
||||
func WithAgentConnFunc(agentConnFn workspacesdk.AgentConnFunc) func(*Deps) {
|
||||
return func(d *Deps) {
|
||||
d.agentConnFn = agentConnFn
|
||||
}
|
||||
}
|
||||
|
||||
// openAgentConn opens a ready workspace agent session for workspace inputs in
|
||||
// [owner/]workspace[.agent] format.
|
||||
func openAgentConn(ctx context.Context, deps Deps, workspace string) (workspacesdk.AgentConn, error) {
|
||||
if deps.coderClient == nil {
|
||||
return nil, xerrors.New("workspace tools require an authenticated client")
|
||||
}
|
||||
|
||||
workspaceName := NormalizeWorkspaceInput(workspace)
|
||||
_, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to find workspace: %w", err)
|
||||
}
|
||||
|
||||
if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
|
||||
FetchInterval: 0,
|
||||
Fetch: deps.coderClient.WorkspaceAgent,
|
||||
FetchLogs: deps.coderClient.WorkspaceAgentLogsAfter,
|
||||
// Always wait for startup scripts.
|
||||
Wait: true,
|
||||
}); err != nil {
|
||||
return nil, xerrors.Errorf("agent not ready: %w", err)
|
||||
}
|
||||
|
||||
conn, release, err := deps.agentConnFn(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to dial agent: %w", err)
|
||||
}
|
||||
|
||||
wrappedConn := workspacesdk.WrapAgentConn(conn, func() error {
|
||||
if release != nil {
|
||||
release()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if wrappedConn == nil {
|
||||
return nil, xerrors.New("agent connection function returned nil connection")
|
||||
}
|
||||
|
||||
return wrappedConn, nil
|
||||
}
|
||||
|
||||
// HandlerFunc is a typed function that handles a tool call.
|
||||
type HandlerFunc[Arg, Ret any] func(context.Context, Deps, Arg) (Ret, error)
|
||||
|
||||
@@ -1501,7 +1561,7 @@ var WorkspaceLS = Tool[WorkspaceLSArgs, WorkspaceLSResponse]{
|
||||
MCPAnnotations: mcpReadOnlyAnnotations,
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args WorkspaceLSArgs) (WorkspaceLSResponse, error) {
|
||||
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
|
||||
conn, err := openAgentConn(ctx, deps, args.Workspace)
|
||||
if err != nil {
|
||||
return WorkspaceLSResponse{}, err
|
||||
}
|
||||
@@ -1567,7 +1627,7 @@ var WorkspaceReadFile = Tool[WorkspaceReadFileArgs, WorkspaceReadFileResponse]{
|
||||
MCPAnnotations: mcpReadOnlyAnnotations,
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args WorkspaceReadFileArgs) (WorkspaceReadFileResponse, error) {
|
||||
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
|
||||
conn, err := openAgentConn(ctx, deps, args.Workspace)
|
||||
if err != nil {
|
||||
return WorkspaceReadFileResponse{}, err
|
||||
}
|
||||
@@ -1641,7 +1701,7 @@ content you are trying to write, then re-encode it properly.
|
||||
MCPAnnotations: mcpDestructiveAnnotations,
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args WorkspaceWriteFileArgs) (codersdk.Response, error) {
|
||||
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
|
||||
conn, err := openAgentConn(ctx, deps, args.Workspace)
|
||||
if err != nil {
|
||||
return codersdk.Response{}, err
|
||||
}
|
||||
@@ -1716,7 +1776,7 @@ var WorkspaceEditFile = Tool[WorkspaceEditFileArgs, WorkspaceEditFilesResponse]{
|
||||
MCPAnnotations: mcpDestructiveAnnotations,
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args WorkspaceEditFileArgs) (WorkspaceEditFilesResponse, error) {
|
||||
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
|
||||
conn, err := openAgentConn(ctx, deps, args.Workspace)
|
||||
if err != nil {
|
||||
return WorkspaceEditFilesResponse{}, err
|
||||
}
|
||||
@@ -1800,7 +1860,7 @@ var WorkspaceEditFiles = Tool[WorkspaceEditFilesArgs, WorkspaceEditFilesResponse
|
||||
MCPAnnotations: mcpDestructiveAnnotations,
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args WorkspaceEditFilesArgs) (WorkspaceEditFilesResponse, error) {
|
||||
conn, err := newAgentConn(ctx, deps.coderClient, args.Workspace)
|
||||
conn, err := openAgentConn(ctx, deps, args.Workspace)
|
||||
if err != nil {
|
||||
return WorkspaceEditFilesResponse{}, err
|
||||
}
|
||||
@@ -2245,41 +2305,6 @@ func NormalizeWorkspaceInput(input string) string {
|
||||
return normalized
|
||||
}
|
||||
|
||||
// newAgentConn returns a connection to the agent specified by the workspace,
|
||||
// which must be in the format [owner/]workspace[.agent].
|
||||
func newAgentConn(ctx context.Context, client *codersdk.Client, workspace string) (workspacesdk.AgentConn, error) {
|
||||
workspaceName := NormalizeWorkspaceInput(workspace)
|
||||
_, workspaceAgent, err := findWorkspaceAndAgent(ctx, client, workspaceName)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to find workspace: %w", err)
|
||||
}
|
||||
|
||||
// Wait for agent to be ready.
|
||||
if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
|
||||
FetchInterval: 0,
|
||||
Fetch: client.WorkspaceAgent,
|
||||
FetchLogs: client.WorkspaceAgentLogsAfter,
|
||||
Wait: true, // Always wait for startup scripts
|
||||
}); err != nil {
|
||||
return nil, xerrors.Errorf("agent not ready: %w", err)
|
||||
}
|
||||
|
||||
wsClient := workspacesdk.New(client)
|
||||
|
||||
conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
BlockEndpoints: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to dial agent: %w", err)
|
||||
}
|
||||
|
||||
if !conn.AwaitReachable(ctx) {
|
||||
conn.Close()
|
||||
return nil, xerrors.New("agent connection not reachable")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
const workspaceDescription = "The workspace ID or name in the format [owner/]workspace. If an owner is not specified, the authenticated user is used."
|
||||
|
||||
const workspaceAgentDescription = "The workspace name in the format [owner/]workspace[.agent]. If an owner is not specified, the authenticated user is used."
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
agentapi "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/aisdk-go"
|
||||
@@ -61,6 +62,22 @@ func setupWorkspaceForAgent(t *testing.T, opts *coderdtest.Options) (*codersdk.C
|
||||
return userClient, r.Workspace, r.AgentToken
|
||||
}
|
||||
|
||||
type recordingAgentConnFunc struct {
|
||||
conn workspacesdk.AgentConn
|
||||
err error
|
||||
agentID uuid.UUID
|
||||
calls int
|
||||
}
|
||||
|
||||
func (d *recordingAgentConnFunc) AgentConn(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
d.calls++
|
||||
d.agentID = agentID
|
||||
if d.err != nil {
|
||||
return nil, nil, d.err
|
||||
}
|
||||
return d.conn, nil, nil
|
||||
}
|
||||
|
||||
// These tests are dependent on the state of the coder server.
|
||||
// Running them in parallel is prone to racy behavior.
|
||||
// nolint:tparallel,paralleltest
|
||||
@@ -597,6 +614,115 @@ func TestTools(t *testing.T) {
|
||||
}, res.Contents)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceToolsUseInjectedAgentConnFunc", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
|
||||
|
||||
ws, err := client.Workspace(t.Context(), workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ws.LatestBuild.Resources)
|
||||
require.NotEmpty(t, ws.LatestBuild.Resources[0].Agents)
|
||||
agentID := ws.LatestBuild.Resources[0].Agents[0].ID
|
||||
sentinelErr := xerrors.New("injected agent connection function used")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T, tb toolsdk.Deps) error
|
||||
}{
|
||||
{
|
||||
name: "WorkspaceLS",
|
||||
run: func(t *testing.T, tb toolsdk.Deps) error {
|
||||
_, err := testTool(t, toolsdk.WorkspaceLS, tb, toolsdk.WorkspaceLSArgs{
|
||||
Workspace: workspace.Name,
|
||||
Path: "/tmp",
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WorkspaceReadFile",
|
||||
run: func(t *testing.T, tb toolsdk.Deps) error {
|
||||
_, err := testTool(t, toolsdk.WorkspaceReadFile, tb, toolsdk.WorkspaceReadFileArgs{
|
||||
Workspace: workspace.Name,
|
||||
Path: "/tmp/file",
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WorkspaceWriteFile",
|
||||
run: func(t *testing.T, tb toolsdk.Deps) error {
|
||||
_, err := testTool(t, toolsdk.WorkspaceWriteFile, tb, toolsdk.WorkspaceWriteFileArgs{
|
||||
Workspace: workspace.Name,
|
||||
Path: "/tmp/file",
|
||||
Content: []byte("hello from agent connection function"),
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WorkspaceEditFile",
|
||||
run: func(t *testing.T, tb toolsdk.Deps) error {
|
||||
_, err := testTool(t, toolsdk.WorkspaceEditFile, tb, toolsdk.WorkspaceEditFileArgs{
|
||||
Workspace: workspace.Name,
|
||||
Path: "/tmp/file",
|
||||
Edits: []workspacesdk.FileEdit{{
|
||||
Search: "hello",
|
||||
Replace: "goodbye",
|
||||
}},
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WorkspaceEditFiles",
|
||||
run: func(t *testing.T, tb toolsdk.Deps) error {
|
||||
_, err := testTool(t, toolsdk.WorkspaceEditFiles, tb, toolsdk.WorkspaceEditFilesArgs{
|
||||
Workspace: workspace.Name,
|
||||
Files: []workspacesdk.FileEdits{{
|
||||
Path: "/tmp/file",
|
||||
Edits: []workspacesdk.FileEdit{{
|
||||
Search: "hello",
|
||||
Replace: "goodbye",
|
||||
}},
|
||||
}},
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "WorkspaceBash",
|
||||
run: func(t *testing.T, tb toolsdk.Deps) error {
|
||||
_, err := testTool(t, toolsdk.WorkspaceBash, tb, toolsdk.WorkspaceBashArgs{
|
||||
Workspace: workspace.Name,
|
||||
Command: "echo hello",
|
||||
})
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agentConnFn := &recordingAgentConnFunc{err: sentinelErr}
|
||||
tb, err := toolsdk.NewDeps(client, toolsdk.WithAgentConnFunc(agentConnFn.AgentConn))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tt.run(t, tb)
|
||||
require.ErrorIs(t, err, sentinelErr)
|
||||
require.ErrorContains(t, err, "failed to dial agent")
|
||||
require.Equal(t, 1, agentConnFn.calls)
|
||||
require.Equal(t, agentID, agentConnFn.agentID)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WorkspaceReadFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -43,6 +44,40 @@ func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) AgentConn {
|
||||
}
|
||||
}
|
||||
|
||||
// WrapAgentConn returns an AgentConn that delegates every operation to conn and
|
||||
// applies closeFunc exactly once when the logical session is closed.
|
||||
//
|
||||
// If conn is nil, any provided closeFunc is invoked immediately so logical
|
||||
// session cleanup is not silently dropped.
|
||||
func WrapAgentConn(conn AgentConn, closeFunc func() error) AgentConn {
|
||||
if conn == nil {
|
||||
if closeFunc != nil {
|
||||
_ = closeFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if closeFunc == nil {
|
||||
closeFunc = func() error { return nil }
|
||||
}
|
||||
return &wrappedAgentConn{AgentConn: conn, closeFunc: closeFunc}
|
||||
}
|
||||
|
||||
type wrappedAgentConn struct {
|
||||
AgentConn
|
||||
closeFunc func() error
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (c *wrappedAgentConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
// Close the underlying connection before releasing the logical session so
|
||||
// the lease remains held until teardown is complete.
|
||||
c.closeErr = errors.Join(c.AgentConn.Close(), c.closeFunc())
|
||||
})
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
const (
|
||||
// CoderChatIDHeader is the HTTP header containing the current
|
||||
// chat ID. Set by coderd on agentconn requests originating
|
||||
|
||||
@@ -175,6 +175,10 @@ func (c *Client) AgentConnectionInfo(ctx context.Context, agentID uuid.UUID) (Ag
|
||||
return connInfo, json.NewDecoder(res.Body).Decode(&connInfo)
|
||||
}
|
||||
|
||||
// AgentConnFunc returns a new connection to the specified agent. If release is
|
||||
// non-nil, callers must invoke it after they are done with the AgentConn.
|
||||
type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (conn AgentConn, release func(), err error)
|
||||
|
||||
// @typescript-ignore DialAgentOptions
|
||||
type DialAgentOptions struct {
|
||||
Logger slog.Logger
|
||||
|
||||
Reference in New Issue
Block a user