diff --git a/agent/agent.go b/agent/agent.go index ee33a20ca1..9ee013b084 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -50,6 +50,7 @@ import ( "github.com/coder/coder/v2/agent/proto/resourcesmonitor" "github.com/coder/coder/v2/agent/reconnectingpty" "github.com/coder/coder/v2/agent/x/agentdesktop" + "github.com/coder/coder/v2/agent/x/agentmcp" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/gitauth" "github.com/coder/coder/v2/coderd/database/dbtime" @@ -311,6 +312,8 @@ type agent struct { gitAPI *agentgit.API processAPI *agentproc.API desktopAPI *agentdesktop.API + mcpManager *agentmcp.Manager + mcpAPI *agentmcp.API socketServerEnabled bool socketPath string @@ -396,6 +399,8 @@ func (a *agent) init() { a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), ) a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock) + a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp")) + a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager) a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), a.sshServer, @@ -1348,6 +1353,14 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, } a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur) a.scriptRunner.StartCron() + + // Connect to workspace MCP servers after the + // lifecycle transition to avoid delaying Ready. + // This runs inside the tracked goroutine so it + // is properly awaited on shutdown. + if mcpErr := a.mcpManager.Connect(a.gracefulCtx, manifest.Directory); mcpErr != nil { + a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr)) + } }) if err != nil { return xerrors.Errorf("track conn goroutine: %w", err) @@ -2070,6 +2083,10 @@ func (a *agent) Close() error { a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err)) } + if err := a.mcpManager.Close(); err != nil { + a.logger.Error(a.hardCtx, "mcp manager close", slog.Error(err)) + } + if a.boundaryLogProxy != nil { err = a.boundaryLogProxy.Close() if err != nil { diff --git a/agent/api.go b/agent/api.go index db21ca85cc..c74533e669 100644 --- a/agent/api.go +++ b/agent/api.go @@ -31,6 +31,7 @@ func (a *agent) apiHandler() http.Handler { r.Mount("/api/v0/git", a.gitAPI.Routes()) r.Mount("/api/v0/processes", a.processAPI.Routes()) r.Mount("/api/v0/desktop", a.desktopAPI.Routes()) + r.Mount("/api/v0/mcp", a.mcpAPI.Routes()) if a.devcontainers { r.Mount("/api/v0/containers", a.containerAPI.Routes()) diff --git a/agent/x/agentmcp/api.go b/agent/x/agentmcp/api.go new file mode 100644 index 0000000000..8582f68f00 --- /dev/null +++ b/agent/x/agentmcp/api.go @@ -0,0 +1,88 @@ +package agentmcp + +import ( + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// API exposes MCP tool discovery and call proxying through the +// agent. +type API struct { + logger slog.Logger + manager *Manager +} + +// NewAPI creates a new MCP API handler backed by the given +// manager. +func NewAPI(logger slog.Logger, manager *Manager) *API { + return &API{ + logger: logger, + manager: manager, + } +} + +// Routes returns the HTTP handler for MCP-related routes. +func (api *API) Routes() http.Handler { + r := chi.NewRouter() + r.Get("/tools", api.handleListTools) + r.Post("/call-tool", api.handleCallTool) + return r +} + +// handleListTools returns the cached MCP tool definitions, +// optionally refreshing them first if ?refresh=true is set. +func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Allow callers to force a tool re-scan before listing. + if r.URL.Query().Get("refresh") == "true" { + if err := api.manager.RefreshTools(ctx); err != nil { + api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err)) + } + } + + tools := api.manager.Tools() + // Ensure non-nil so JSON serialization returns [] not null. + if tools == nil { + tools = []workspacesdk.MCPToolInfo{} + } + + httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListMCPToolsResponse{ + Tools: tools, + }) +} + +// handleCallTool proxies a tool invocation to the appropriate +// MCP server based on the tool name prefix. +func (api *API) handleCallTool(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var req workspacesdk.CallMCPToolRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + resp, err := api.manager.CallTool(ctx, req) + if err != nil { + status := http.StatusBadGateway + if errors.Is(err, ErrInvalidToolName) { + status = http.StatusBadRequest + } else if errors.Is(err, ErrUnknownServer) { + status = http.StatusNotFound + } + httpapi.Write(ctx, rw, status, codersdk.Response{ + Message: "MCP tool call failed.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, resp) +} diff --git a/agent/x/agentmcp/config.go b/agent/x/agentmcp/config.go new file mode 100644 index 0000000000..1899119157 --- /dev/null +++ b/agent/x/agentmcp/config.go @@ -0,0 +1,115 @@ +package agentmcp + +import ( + "encoding/json" + "os" + "slices" + "strings" + + "golang.org/x/xerrors" +) + +// ServerConfig describes a single MCP server parsed from a .mcp.json file. +type ServerConfig struct { + Name string `json:"name"` + Transport string `json:"type"` + Command string `json:"command"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` +} + +// mcpConfigFile mirrors the on-disk .mcp.json schema. +type mcpConfigFile struct { + MCPServers map[string]json.RawMessage `json:"mcpServers"` +} + +// mcpServerEntry is a single server block inside mcpServers. +type mcpServerEntry struct { + Command string `json:"command"` + Args []string `json:"args"` + Env map[string]string `json:"env"` + Type string `json:"type"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` +} + +// ParseConfig reads a .mcp.json file at path and returns the declared +// MCP servers sorted by name. It returns an empty slice when the +// mcpServers key is missing or empty. +func ParseConfig(path string) ([]ServerConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, xerrors.Errorf("read mcp config %q: %w", path, err) + } + + var cfg mcpConfigFile + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, xerrors.Errorf("parse mcp config %q: %w", path, err) + } + + if len(cfg.MCPServers) == 0 { + return []ServerConfig{}, nil + } + + servers := make([]ServerConfig, 0, len(cfg.MCPServers)) + for name, raw := range cfg.MCPServers { + var entry mcpServerEntry + if err := json.Unmarshal(raw, &entry); err != nil { + return nil, xerrors.Errorf("parse server %q in %q: %w", name, path, err) + } + + if strings.Contains(name, ToolNameSep) || strings.HasPrefix(name, "_") || strings.HasSuffix(name, "_") { + return nil, xerrors.Errorf("server name %q in %q contains reserved separator %q or leading/trailing underscore", name, path, ToolNameSep) + } + + transport := inferTransport(entry) + + if transport == "" { + return nil, xerrors.Errorf("server %q in %q has no command or url", name, path) + } + + resolveEnvVars(entry.Env) + + servers = append(servers, ServerConfig{ + Name: name, + Transport: transport, + Command: entry.Command, + Args: entry.Args, + Env: entry.Env, + URL: entry.URL, + Headers: entry.Headers, + }) + } + + slices.SortFunc(servers, func(a, b ServerConfig) int { + return strings.Compare(a.Name, b.Name) + }) + + return servers, nil +} + +// inferTransport determines the transport type for a server entry. +// An explicit "type" field takes priority; otherwise the presence +// of "command" implies stdio and "url" implies http. +func inferTransport(e mcpServerEntry) string { + if e.Type != "" { + return e.Type + } + if e.Command != "" { + return "stdio" + } + if e.URL != "" { + return "http" + } + return "" +} + +// resolveEnvVars expands ${VAR} references in env map values +// using the current process environment. +func resolveEnvVars(env map[string]string) { + for k, v := range env { + env[k] = os.Expand(v, os.Getenv) + } +} diff --git a/agent/x/agentmcp/config_test.go b/agent/x/agentmcp/config_test.go new file mode 100644 index 0000000000..80466c959b --- /dev/null +++ b/agent/x/agentmcp/config_test.go @@ -0,0 +1,254 @@ +package agentmcp_test + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/x/agentmcp" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + expected []agentmcp.ServerConfig + expectError bool + }{ + { + name: "StdioServer", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "my-server": map[string]any{ + "command": "npx", + "args": []string{"-y", "@example/mcp-server"}, + "env": map[string]string{"FOO": "bar"}, + }, + }, + }), + expected: []agentmcp.ServerConfig{ + { + Name: "my-server", + Transport: "stdio", + Command: "npx", + Args: []string{"-y", "@example/mcp-server"}, + Env: map[string]string{"FOO": "bar"}, + }, + }, + }, + { + name: "HTTPServer", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "remote": map[string]any{ + "url": "https://example.com/mcp", + "headers": map[string]string{"Authorization": "Bearer tok"}, + }, + }, + }), + expected: []agentmcp.ServerConfig{ + { + Name: "remote", + Transport: "http", + URL: "https://example.com/mcp", + Headers: map[string]string{"Authorization": "Bearer tok"}, + }, + }, + }, + { + name: "SSEServer", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "events": map[string]any{ + "type": "sse", + "url": "https://example.com/sse", + }, + }, + }), + expected: []agentmcp.ServerConfig{ + { + Name: "events", + Transport: "sse", + URL: "https://example.com/sse", + }, + }, + }, + { + name: "ExplicitTypeOverridesInference", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "hybrid": map[string]any{ + "command": "some-binary", + "type": "http", + }, + }, + }), + expected: []agentmcp.ServerConfig{ + { + Name: "hybrid", + Transport: "http", + Command: "some-binary", + }, + }, + }, + { + name: "EnvVarPassthrough", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "srv": map[string]any{ + "command": "run", + "env": map[string]string{"PLAIN": "literal-value"}, + }, + }, + }), + expected: []agentmcp.ServerConfig{ + { + Name: "srv", + Transport: "stdio", + Command: "run", + Env: map[string]string{"PLAIN": "literal-value"}, + }, + }, + }, + { + name: "EmptyMCPServers", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{}, + }), + expected: []agentmcp.ServerConfig{}, + }, + { + name: "MalformedJSON", + content: `{not valid json`, + expectError: true, + }, + { + name: "ServerNameContainsSeparator", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "bad__name": map[string]any{"command": "run"}, + }, + }), + expectError: true, + }, + { + name: "ServerNameTrailingUnderscore", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "server_": map[string]any{"command": "run"}, + }, + }), + expectError: true, + }, + { + name: "ServerNameLeadingUnderscore", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "_server": map[string]any{"command": "run"}, + }, + }), + expectError: true, + }, + { + name: "EmptyTransport", content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "empty": map[string]any{}, + }, + }), + expectError: true, + }, + { + name: "MissingMCPServersKey", + content: mustJSON(t, map[string]any{ + "servers": map[string]any{}, + }), + expected: []agentmcp.ServerConfig{}, + }, + { + name: "MultipleServersSortedByName", + content: mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "zeta": map[string]any{"command": "z"}, + "alpha": map[string]any{"command": "a"}, + "mu": map[string]any{"command": "m"}, + }, + }), + expected: []agentmcp.ServerConfig{ + {Name: "alpha", Transport: "stdio", Command: "a"}, + {Name: "mu", Transport: "stdio", Command: "m"}, + {Name: "zeta", Transport: "stdio", Command: "z"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, ".mcp.json") + err := os.WriteFile(path, []byte(tt.content), 0o600) + require.NoError(t, err) + + got, err := agentmcp.ParseConfig(path) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, got) + }) + } +} + +// TestParseConfig_EnvVarInterpolation verifies that ${VAR} references +// in env values are resolved from the process environment. This test +// cannot be parallel because t.Setenv is incompatible with t.Parallel. +func TestParseConfig_EnvVarInterpolation(t *testing.T) { + t.Setenv("TEST_MCP_TOKEN", "secret123") + + content := mustJSON(t, map[string]any{ + "mcpServers": map[string]any{ + "srv": map[string]any{ + "command": "run", + "env": map[string]string{"TOKEN": "${TEST_MCP_TOKEN}"}, + }, + }, + }) + + dir := t.TempDir() + path := filepath.Join(dir, ".mcp.json") + err := os.WriteFile(path, []byte(content), 0o600) + require.NoError(t, err) + + got, err := agentmcp.ParseConfig(path) + require.NoError(t, err) + require.Equal(t, []agentmcp.ServerConfig{ + { + Name: "srv", + Transport: "stdio", + Command: "run", + Env: map[string]string{"TOKEN": "secret123"}, + }, + }, got) +} + +func TestParseConfig_FileNotFound(t *testing.T) { + t.Parallel() + + _, err := agentmcp.ParseConfig(filepath.Join(t.TempDir(), "nonexistent.json")) + require.Error(t, err) +} + +// mustJSON marshals v to a JSON string, failing the test on error. +func mustJSON(t *testing.T, v any) string { + t.Helper() + data, err := json.Marshal(v) + require.NoError(t, err) + return string(data) +} diff --git a/agent/x/agentmcp/manager.go b/agent/x/agentmcp/manager.go new file mode 100644 index 0000000000..f8c4445cce --- /dev/null +++ b/agent/x/agentmcp/manager.go @@ -0,0 +1,447 @@ +package agentmcp + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "slices" + "strings" + "sync" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/buildinfo" + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// ToolNameSep separates the server name from the original tool name +// in prefixed tool names. Double underscore avoids collisions with +// tool names that may contain single underscores. +const ToolNameSep = "__" + +// connectTimeout bounds how long we wait for a single MCP server +// to start its transport and complete initialization. +const connectTimeout = 30 * time.Second + +// toolCallTimeout bounds how long a single tool invocation may +// take before being canceled. +const toolCallTimeout = 60 * time.Second + +var ( + // ErrInvalidToolName is returned when the tool name format + // is not "server__tool". + ErrInvalidToolName = xerrors.New("invalid tool name format") + // ErrUnknownServer is returned when no MCP server matches + // the prefix in the tool name. + ErrUnknownServer = xerrors.New("unknown MCP server") +) + +// Manager manages connections to MCP servers discovered from a +// workspace's .mcp.json file. It caches the aggregated tool list +// and proxies tool calls to the appropriate server. +type Manager struct { + mu sync.RWMutex + logger slog.Logger + closed bool + servers map[string]*serverEntry // keyed by server name + tools []workspacesdk.MCPToolInfo +} + +// serverEntry pairs a server config with its connected client. +type serverEntry struct { + config ServerConfig + client *client.Client +} + +// NewManager creates a new MCP client manager. +func NewManager(logger slog.Logger) *Manager { + return &Manager{ + logger: logger, + servers: make(map[string]*serverEntry), + } +} + +// Connect discovers .mcp.json in dir and connects to all +// configured servers. Failed servers are logged and skipped. +func (m *Manager) Connect(ctx context.Context, dir string) error { + path := filepath.Join(dir, ".mcp.json") + configs, err := ParseConfig(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return xerrors.Errorf("parse mcp config: %w", err) + } + + // Connect to servers in parallel without holding the + // lock, since each connectServer call may block on + // network I/O for up to connectTimeout. + type connectedServer struct { + name string + config ServerConfig + client *client.Client + } + var ( + mu sync.Mutex + connected []connectedServer + ) + var eg errgroup.Group + for _, cfg := range configs { + eg.Go(func() error { + c, err := m.connectServer(ctx, cfg) + if err != nil { + m.logger.Warn(ctx, "skipping MCP server", + slog.F("server", cfg.Name), + slog.F("transport", cfg.Transport), + slog.Error(err), + ) + return nil // Don't fail the group. + } + mu.Lock() + connected = append(connected, connectedServer{ + name: cfg.Name, config: cfg, client: c, + }) + mu.Unlock() + return nil + }) + } + _ = eg.Wait() + + m.mu.Lock() + if m.closed { + m.mu.Unlock() + // Close the freshly-connected clients since we're + // shutting down. + for _, cs := range connected { + _ = cs.client.Close() + } + return xerrors.New("manager closed") + } + + // Close previous connections to avoid leaking child + // processes on agent reconnect. + for _, entry := range m.servers { + _ = entry.client.Close() + } + m.servers = make(map[string]*serverEntry, len(connected)) + + for _, cs := range connected { + m.servers[cs.name] = &serverEntry{ + config: cs.config, + client: cs.client, + } + } + m.mu.Unlock() + + // Refresh tools outside the lock to avoid blocking + // concurrent reads during network I/O. + if err := m.RefreshTools(ctx); err != nil { + m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) + } + return nil +} + +// connectServer establishes a connection to a single MCP server +// and returns the connected client. It does not modify any Manager +// state. +func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) { + tr, err := createTransport(cfg) + if err != nil { + return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err) + } + + c := client.NewClient(tr) + + connectCtx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + + if err := c.Start(connectCtx); err != nil { + _ = c.Close() + return nil, xerrors.Errorf("start %q: %w", cfg.Name, err) + } + + _, err = c.Initialize(connectCtx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "coder-agent", + Version: buildinfo.Version(), + }, + }, + }) + if err != nil { + _ = c.Close() + return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err) + } + + return c, nil +} + +// createTransport builds the mcp-go transport for a server config. +func createTransport(cfg ServerConfig) (transport.Interface, error) { + switch cfg.Transport { + case "stdio": + return transport.NewStdio( + cfg.Command, + buildEnv(cfg.Env), + cfg.Args..., + ), nil + case "http", "": + return transport.NewStreamableHTTP( + cfg.URL, + transport.WithHTTPHeaders(cfg.Headers), + ) + case "sse": + return transport.NewSSE( + cfg.URL, + transport.WithHeaders(cfg.Headers), + ) + default: + return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport) + } +} + +// buildEnv merges the current process environment with explicit +// overrides, returning the result as KEY=VALUE strings suitable +// for the stdio transport. +func buildEnv(explicit map[string]string) []string { + env := os.Environ() + if len(explicit) == 0 { + return env + } + + // Index existing env so explicit keys can override in-place. + existing := make(map[string]int, len(env)) + for i, kv := range env { + if k, _, ok := strings.Cut(kv, "="); ok { + existing[k] = i + } + } + + for k, v := range explicit { + entry := k + "=" + v + if idx, ok := existing[k]; ok { + env[idx] = entry + } else { + env = append(env, entry) + } + } + return env +} + +// Tools returns the cached tool list. Thread-safe. +func (m *Manager) Tools() []workspacesdk.MCPToolInfo { + m.mu.RLock() + defer m.mu.RUnlock() + + return slices.Clone(m.tools) +} + +// CallTool proxies a tool call to the appropriate MCP server. +func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + serverName, originalName, err := splitToolName(req.ToolName) + if err != nil { + return workspacesdk.CallMCPToolResponse{}, err + } + + m.mu.RLock() + entry, ok := m.servers[serverName] + m.mu.RUnlock() + + if !ok { + return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("%w: %q", ErrUnknownServer, serverName) + } + + callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout) + defer cancel() + + result, err := entry.client.CallTool(callCtx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: originalName, + Arguments: req.Arguments, + }, + }) + if err != nil { + return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("call tool %q on %q: %w", originalName, serverName, err) + } + + return convertResult(result), nil +} + +// splitToolName extracts the server name and original tool name +// from a prefixed tool name like "server__tool". +func splitToolName(prefixed string) (serverName, toolName string, err error) { + server, tool, ok := strings.Cut(prefixed, ToolNameSep) + if !ok || server == "" || tool == "" { + return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed) + } + return server, tool, nil +} + +// convertResult translates an MCP CallToolResult into a +// workspacesdk.CallMCPToolResponse. It iterates over content +// items and maps each recognized type. +func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse { + if result == nil { + return workspacesdk.CallMCPToolResponse{} + } + + var content []workspacesdk.MCPToolContent + for _, item := range result.Content { + switch c := item.(type) { + case mcp.TextContent: + content = append(content, workspacesdk.MCPToolContent{ + Type: "text", + Text: c.Text, + }) + case mcp.ImageContent: + content = append(content, workspacesdk.MCPToolContent{ + Type: "image", + Data: c.Data, + MediaType: c.MIMEType, + }) + case mcp.AudioContent: + content = append(content, workspacesdk.MCPToolContent{ + Type: "audio", + Data: c.Data, + MediaType: c.MIMEType, + }) + case mcp.EmbeddedResource: + content = append(content, workspacesdk.MCPToolContent{ + Type: "resource", + Text: fmt.Sprintf("[embedded resource: %T]", c.Resource), + }) + case mcp.ResourceLink: + content = append(content, workspacesdk.MCPToolContent{ + Type: "resource", + Text: fmt.Sprintf("[resource link: %s]", c.URI), + }) + default: + content = append(content, workspacesdk.MCPToolContent{ + Type: "text", + Text: fmt.Sprintf("[unsupported content type: %T]", item), + }) + } + } + + return workspacesdk.CallMCPToolResponse{ + Content: content, + IsError: result.IsError, + } +} + +// RefreshTools re-fetches tool lists from all connected servers +// in parallel and rebuilds the cache. On partial failure, tools +// from servers that responded successfully are merged with the +// existing cached tools for servers that failed, so a single +// dead server doesn't block updates from healthy ones. +func (m *Manager) RefreshTools(ctx context.Context) error { + // Snapshot servers under read lock. + m.mu.RLock() + servers := make(map[string]*serverEntry, len(m.servers)) + for k, v := range m.servers { + servers[k] = v + } + m.mu.RUnlock() + + // Fetch tool lists in parallel without holding any lock. + type serverTools struct { + name string + tools []workspacesdk.MCPToolInfo + } + var ( + mu sync.Mutex + results []serverTools + failed []string + errs []error + ) + var eg errgroup.Group + for name, entry := range servers { + eg.Go(func() error { + listCtx, cancel := context.WithTimeout(ctx, connectTimeout) + result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{}) + cancel() + if err != nil { + m.logger.Warn(ctx, "failed to list tools from MCP server", + slog.F("server", name), + slog.Error(err), + ) + mu.Lock() + errs = append(errs, xerrors.Errorf("list tools from %q: %w", name, err)) + failed = append(failed, name) + mu.Unlock() + return nil + } + var tools []workspacesdk.MCPToolInfo + for _, tool := range result.Tools { + tools = append(tools, workspacesdk.MCPToolInfo{ + ServerName: name, + Name: name + ToolNameSep + tool.Name, + Description: tool.Description, + Schema: tool.InputSchema.Properties, + Required: tool.InputSchema.Required, + }) + } + mu.Lock() + results = append(results, serverTools{name: name, tools: tools}) + mu.Unlock() + return nil + }) + } + _ = eg.Wait() + + // Build the new tool list. For servers that failed, preserve + // their tools from the existing cache so a single dead server + // doesn't remove healthy tools. + var merged []workspacesdk.MCPToolInfo + for _, st := range results { + merged = append(merged, st.tools...) + } + if len(failed) > 0 { + failedSet := make(map[string]struct{}, len(failed)) + for _, f := range failed { + failedSet[f] = struct{}{} + } + m.mu.RLock() + for _, t := range m.tools { + if _, ok := failedSet[t.ServerName]; ok { + merged = append(merged, t) + } + } + m.mu.RUnlock() + } + slices.SortFunc(merged, func(a, b workspacesdk.MCPToolInfo) int { + return strings.Compare(a.Name, b.Name) + }) + + m.mu.Lock() + m.tools = merged + m.mu.Unlock() + + return errors.Join(errs...) +} + +// Close terminates all MCP server connections and child +// processes. +func (m *Manager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + m.closed = true + var errs []error + for _, entry := range m.servers { + errs = append(errs, entry.client.Close()) + } + m.servers = make(map[string]*serverEntry) + m.tools = nil + return errors.Join(errs...) +} diff --git a/agent/x/agentmcp/manager_internal_test.go b/agent/x/agentmcp/manager_internal_test.go new file mode 100644 index 0000000000..3915b2126e --- /dev/null +++ b/agent/x/agentmcp/manager_internal_test.go @@ -0,0 +1,195 @@ +package agentmcp + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +func TestSplitToolName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantServer string + wantTool string + wantErr bool + }{ + { + name: "Valid", + input: "server__tool", + wantServer: "server", + wantTool: "tool", + }, + { + name: "ValidWithUnderscoresInTool", + input: "server__my_tool", + wantServer: "server", + wantTool: "my_tool", + }, + { + name: "MissingSeparator", + input: "servertool", + wantErr: true, + }, + { + name: "EmptyServer", + input: "__tool", + wantErr: true, + }, + { + name: "EmptyTool", + input: "server__", + wantErr: true, + }, + { + name: "JustSeparator", + input: "__", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server, tool, err := splitToolName(tt.input) + if tt.wantErr { + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidToolName) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantServer, server) + assert.Equal(t, tt.wantTool, tool) + }) + } +} + +func TestConvertResult(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + // input is a pointer so we can test nil. + input *mcp.CallToolResult + want workspacesdk.CallMCPToolResponse + }{ + { + name: "NilInput", + input: nil, + want: workspacesdk.CallMCPToolResponse{}, + }, + { + name: "TextContent", + input: &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{Type: "text", Text: "hello"}, + }, + }, + want: workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{ + {Type: "text", Text: "hello"}, + }, + }, + }, + { + name: "ImageContent", + input: &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.ImageContent{ + Type: "image", + Data: "base64data", + MIMEType: "image/png", + }, + }, + }, + want: workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{ + {Type: "image", Data: "base64data", MediaType: "image/png"}, + }, + }, + }, + { + name: "AudioContent", + input: &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.AudioContent{ + Type: "audio", + Data: "base64audio", + MIMEType: "audio/mp3", + }, + }, + }, + want: workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{ + {Type: "audio", Data: "base64audio", MediaType: "audio/mp3"}, + }, + }, + }, + { + name: "IsErrorPropagation", + input: &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{Type: "text", Text: "fail"}, + }, + IsError: true, + }, + want: workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{ + {Type: "text", Text: "fail"}, + }, + IsError: true, + }, + }, + { + name: "MultipleContentItems", + input: &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{Type: "text", Text: "caption"}, + mcp.ImageContent{ + Type: "image", + Data: "imgdata", + MIMEType: "image/jpeg", + }, + }, + }, + want: workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{ + {Type: "text", Text: "caption"}, + {Type: "image", Data: "imgdata", MediaType: "image/jpeg"}, + }, + }, + }, + { + name: "ResourceLink", + input: &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.ResourceLink{ + Type: "resource_link", + URI: "file:///tmp/test.txt", + }, + }, + }, + want: workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{ + {Type: "resource", Text: "[resource link: file:///tmp/test.txt]"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := convertResult(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 1669e8536f..2c31176e4a 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -116,6 +116,11 @@ type Server struct { // never contend with each other. chatStreams sync.Map // uuid.UUID -> *chatStreamState + // workspaceMCPToolsCache caches workspace MCP tool definitions + // per chat to avoid re-fetching on every turn. The cache is + // keyed by chat ID and invalidated when the agent changes. + workspaceMCPToolsCache sync.Map // uuid.UUID -> *cachedWorkspaceMCPTools + usageTracker *workspacestats.UsageTracker clock quartz.Clock @@ -156,6 +161,13 @@ func (p *Server) chatTemplateAllowlist() map[uuid.UUID]bool { return m } +// cachedWorkspaceMCPTools stores workspace MCP tools discovered +// from a workspace agent, keyed by the agent ID that provided them. +type cachedWorkspaceMCPTools struct { + agentID uuid.UUID + tools []workspacesdk.MCPToolInfo +} + type turnWorkspaceContext struct { server *Server chatStateMu *sync.Mutex @@ -2020,6 +2032,7 @@ func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState { func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) { if !state.buffering && len(state.subscribers) == 0 { p.chatStreams.Delete(chatID) + p.workspaceMCPToolsCache.Delete(chatID) } } @@ -3240,6 +3253,7 @@ func (p *Server) runChat( resolvedUserPrompt string mcpTools []fantasy.AgentTool mcpCleanup func() + workspaceMCPTools []fantasy.AgentTool ) // Check if instruction files need to be (re-)persisted. // This happens when no context-file parts exist yet, or when @@ -3295,6 +3309,62 @@ func (p *Server) runChat( return nil }) } + if chat.WorkspaceID.Valid { + g2.Go(func() error { + // Check cache first. On subsequent turns with the same + // agent, reuse cached tools to avoid a round-trip. + if cached, ok := p.workspaceMCPToolsCache.Load(chat.ID); ok { + entry, ok2 := cached.(*cachedWorkspaceMCPTools) + if !ok2 { + return nil + } + // Verify the agent hasn't changed. + if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID == entry.agentID { + for _, t := range entry.tools { + workspaceMCPTools = append(workspaceMCPTools, + chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn), + ) + } + return nil + } + } + + // Cache miss or agent changed — fetch fresh tools. + conn, connErr := workspaceCtx.getWorkspaceConn(ctx) + if connErr != nil { + logger.Warn(ctx, "failed to get workspace conn for MCP tools", + slog.Error(connErr)) + return nil + } + toolsResp, listErr := conn.ListMCPTools(ctx) + if listErr != nil { + logger.Warn(ctx, "failed to list workspace MCP tools", + slog.Error(listErr)) + return nil + } + + // Cache the result for subsequent turns. Skip + // caching when the list is empty because the + // agent's MCP Connect may not have finished yet; + // caching an empty list would hide tools + // permanently. + if len(toolsResp.Tools) > 0 { + if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil { + p.workspaceMCPToolsCache.Store(chat.ID, &cachedWorkspaceMCPTools{ + agentID: agent.ID, + tools: toolsResp.Tools, + }) + } + } + + for _, t := range toolsResp.Tools { + workspaceMCPTools = append(workspaceMCPTools, + chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn), + ) + } + return nil + }) + } // All g2 goroutines return nil; error is discarded. _ = g2.Wait() if mcpCleanup != nil { @@ -3713,6 +3783,7 @@ func (p *Server) runChat( // after the built-in tools so the LLM sees them as // additional capabilities. tools = append(tools, mcpTools...) + tools = append(tools, workspaceMCPTools...) // Build provider-native tools (e.g., web search) based on // the model configuration. diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index b3f846117b..783d268c22 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -1551,6 +1551,10 @@ func TestPersistToolResultWithBinaryData(t *testing.T) { mockConn.EXPECT(). SetExtraHeaders(gomock.Any()). AnyTimes() + mockConn.EXPECT(). + ListMCPTools(gomock.Any()). + Return(workspacesdk.ListMCPToolsResponse{}, nil). + AnyTimes() mockConn.EXPECT(). LS(gomock.Any(), gomock.Any(), gomock.Any()). Return(workspacesdk.LSResponse{}, nil). @@ -3151,6 +3155,10 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { // for the initial screenshot check in the computer use path. ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) + mockConn.EXPECT(). + ListMCPTools(gomock.Any()). + Return(workspacesdk.ListMCPToolsResponse{}, nil). + AnyTimes() mockConn.EXPECT(). ExecuteDesktopAction(gomock.Any(), gomock.Any()). Return(workspacesdk.DesktopActionResponse{ @@ -3595,6 +3603,8 @@ func TestMCPServerToolInvocation(t *testing.T) { ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes() + mockConn.EXPECT().ListMCPTools(gomock.Any()). + Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes() mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()). Return(workspacesdk.LSResponse{}, nil).AnyTimes() mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). diff --git a/coderd/x/chatd/chattool/mcpworkspace.go b/coderd/x/chatd/chattool/mcpworkspace.go new file mode 100644 index 0000000000..dc175ab43e --- /dev/null +++ b/coderd/x/chatd/chattool/mcpworkspace.go @@ -0,0 +1,151 @@ +package chattool + +import ( + "context" + "encoding/base64" + "encoding/json" + "strings" + + "charm.land/fantasy" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// WorkspaceMCPTool wraps a single MCP tool discovered in a +// workspace, proxying calls through the workspace agent +// connection. It implements fantasy.AgentTool so it can be +// registered alongside built-in chat tools. +type WorkspaceMCPTool struct { + info fantasy.ToolInfo + getConn func(context.Context) (workspacesdk.AgentConn, error) + providerOpts fantasy.ProviderOptions +} + +// NewWorkspaceMCPTool creates a tool wrapper from an MCPToolInfo +// discovered on a workspace agent. Each tool proxies calls back +// through the agent connection. +func NewWorkspaceMCPTool( + tool workspacesdk.MCPToolInfo, + getConn func(context.Context) (workspacesdk.AgentConn, error), +) *WorkspaceMCPTool { + required := tool.Required + if required == nil { + required = []string{} + } + return &WorkspaceMCPTool{ + info: fantasy.ToolInfo{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Schema, + Required: required, + Parallel: true, + }, + getConn: getConn, + } +} + +func (t *WorkspaceMCPTool) Info() fantasy.ToolInfo { + return t.info +} + +func (t *WorkspaceMCPTool) Run( + ctx context.Context, + params fantasy.ToolCall, +) (fantasy.ToolResponse, error) { + conn, err := t.getConn(ctx) + if err != nil { + return fantasy.NewTextErrorResponse( + "workspace connection failed: " + err.Error(), + ), nil + } + + var args map[string]any + if params.Input != "" { + if err := json.Unmarshal( + []byte(params.Input), &args, + ); err != nil { + return fantasy.NewTextErrorResponse( + "invalid JSON input: " + err.Error(), + ), nil + } + } + + resp, err := conn.CallMCPTool(ctx, workspacesdk.CallMCPToolRequest{ + ToolName: t.info.Name, + Arguments: args, + }) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + + return convertMCPToolResponse(resp), nil +} + +func (t *WorkspaceMCPTool) ProviderOptions() fantasy.ProviderOptions { + return t.providerOpts +} + +func (t *WorkspaceMCPTool) SetProviderOptions( + opts fantasy.ProviderOptions, +) { + t.providerOpts = opts +} + +// convertMCPToolResponse translates a workspace agent MCP tool +// response into a fantasy.ToolResponse. Text content blocks are +// collected and joined; binary content (image/media) is returned +// only when no text is available, matching the mcpclient +// conversion strategy. +func convertMCPToolResponse( + resp workspacesdk.CallMCPToolResponse, +) fantasy.ToolResponse { + var ( + textParts []string + binaryResult *fantasy.ToolResponse + ) + + for _, c := range resp.Content { + switch c.Type { + case "text": + textParts = append(textParts, c.Text) + case "image", "audio": + if c.Data == "" { + continue + } + data, err := base64.StdEncoding.DecodeString(c.Data) + if err != nil { + textParts = append(textParts, + "[binary decode error: "+err.Error()+"]", + ) + continue + } + if binaryResult == nil { + r := fantasy.ToolResponse{ + Type: c.Type, + Data: data, + MediaType: c.MediaType, + IsError: resp.IsError, + } + binaryResult = &r + } + default: + textParts = append(textParts, c.Text) + } + } + + // Prefer text content. Only fall back to binary when no + // text was collected. + if len(textParts) > 0 { + r := fantasy.NewTextResponse( + strings.Join(textParts, "\n"), + ) + r.IsError = resp.IsError + return r + } + if binaryResult != nil { + return *binaryResult + } + r := fantasy.NewTextResponse("") + r.IsError = resp.IsError + return r +} diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index acb22e1d29..c8ee83894d 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -59,6 +59,7 @@ type AgentConn interface { SetExtraHeaders(h http.Header) AwaitReachable(ctx context.Context) bool + CallMCPTool(ctx context.Context, req CallMCPToolRequest) (CallMCPToolResponse, error) Close() error DebugLogs(ctx context.Context) ([]byte, error) DebugMagicsock(ctx context.Context) ([]byte, error) @@ -66,6 +67,7 @@ type AgentConn interface { DialContext(ctx context.Context, network string, addr string) (net.Conn, error) GetPeerDiagnostics() tailnet.PeerDiagnostics ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) + ListMCPTools(ctx context.Context) (ListMCPToolsResponse, error) ListProcesses(ctx context.Context) (ListProcessesResponse, error) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) @@ -923,6 +925,50 @@ type FileEditRequest struct { Files []FileEdits `json:"files"` } +// ListMCPToolsResponse is the response from the agent's +// MCP tool discovery endpoint. +type ListMCPToolsResponse struct { + Tools []MCPToolInfo `json:"tools"` +} + +// MCPToolInfo describes a single tool discovered from an MCP +// server configured in the workspace's .mcp.json file. +type MCPToolInfo struct { + // ServerName is the key from .mcp.json (e.g. "github"). + ServerName string `json:"server_name"` + // Name is the prefixed tool name: "serverName__toolName". + Name string `json:"name"` + // Description is the tool's human-readable description. + Description string `json:"description"` + // Schema is the JSON Schema for the tool's input parameters. + Schema map[string]any `json:"schema"` + // Required lists required parameter names. + Required []string `json:"required"` +} + +// CallMCPToolRequest is the request body for proxying an MCP +// tool call through the workspace agent. +type CallMCPToolRequest struct { + // ToolName is the prefixed tool name (e.g. "github__create_issue"). + ToolName string `json:"tool_name"` + // Arguments is the tool input as key-value pairs. + Arguments map[string]any `json:"arguments"` +} + +// CallMCPToolResponse is the response from a proxied MCP tool call. +type CallMCPToolResponse struct { + Content []MCPToolContent `json:"content"` + IsError bool `json:"is_error"` +} + +// MCPToolContent is a single content block in an MCP tool response. +type MCPToolContent struct { + Type string `json:"type"` // "text", "image", "audio", "resource" + Text string `json:"text,omitempty"` + Data string `json:"data,omitempty"` // base64 for binary + MediaType string `json:"media_type,omitempty"` +} + // StartProcess starts a new process on the workspace agent. func (c *agentConn) StartProcess(ctx context.Context, req StartProcessRequest) (StartProcessResponse, error) { ctx, span := tracing.StartSpan(ctx) @@ -955,6 +1001,40 @@ func (c *agentConn) ListProcesses(ctx context.Context) (ListProcessesResponse, e return resp, json.NewDecoder(res.Body).Decode(&resp) } +// ListMCPTools returns tools discovered from MCP servers configured +// in the workspace. +func (c *agentConn) ListMCPTools(ctx context.Context) (ListMCPToolsResponse, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/mcp/tools", nil) + if err != nil { + return ListMCPToolsResponse{}, xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ListMCPToolsResponse{}, codersdk.ReadBodyAsError(res) + } + var resp ListMCPToolsResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// CallMCPTool proxies a tool call to an MCP server running in +// the workspace. +func (c *agentConn) CallMCPTool(ctx context.Context, req CallMCPToolRequest) (CallMCPToolResponse, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/mcp/call-tool", req) + if err != nil { + return CallMCPToolResponse{}, xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return CallMCPToolResponse{}, codersdk.ReadBodyAsError(res) + } + var resp CallMCPToolResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + // ProcessOutput returns the output of a tracked process on the agent. func (c *agentConn) ProcessOutput(ctx context.Context, id string, opts *ProcessOutputOptions) (ProcessOutputResponse, error) { ctx, span := tracing.StartSpan(ctx) diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index 6a83a8bd8d..f3860e70c6 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -69,6 +69,21 @@ func (mr *MockAgentConnMockRecorder) AwaitReachable(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitReachable", reflect.TypeOf((*MockAgentConn)(nil).AwaitReachable), ctx) } +// CallMCPTool mocks base method. +func (m *MockAgentConn) CallMCPTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CallMCPTool", ctx, req) + ret0, _ := ret[0].(workspacesdk.CallMCPToolResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CallMCPTool indicates an expected call of CallMCPTool. +func (mr *MockAgentConnMockRecorder) CallMCPTool(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallMCPTool", reflect.TypeOf((*MockAgentConn)(nil).CallMCPTool), ctx, req) +} + // Close mocks base method. func (m *MockAgentConn) Close() error { m.ctrl.T.Helper() @@ -245,6 +260,21 @@ func (mr *MockAgentConnMockRecorder) ListContainers(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListContainers", reflect.TypeOf((*MockAgentConn)(nil).ListContainers), ctx) } +// ListMCPTools mocks base method. +func (m *MockAgentConn) ListMCPTools(ctx context.Context) (workspacesdk.ListMCPToolsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListMCPTools", ctx) + ret0, _ := ret[0].(workspacesdk.ListMCPToolsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListMCPTools indicates an expected call of ListMCPTools. +func (mr *MockAgentConnMockRecorder) ListMCPTools(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListMCPTools", reflect.TypeOf((*MockAgentConn)(nil).ListMCPTools), ctx) +} + // ListProcesses mocks base method. func (m *MockAgentConn) ListProcesses(ctx context.Context) (workspacesdk.ListProcessesResponse, error) { m.ctrl.T.Helper()