feat: add workspace MCP tool discovery and proxying for chat (#23680)

Coder's chat (chatd) can now discover and use MCP servers configured in
a workspace's `.mcp.json` file. This brings project-specific tooling
(GitHub, databases, docs servers, etc.) into the chat without any manual
configuration.

## How it works

The workspace agent reads `.mcp.json` from the workspace directory (same
format Claude Code uses), connects to the declared MCP servers —
spawning child processes for stdio servers and connecting over the
network for HTTP/SSE — and caches their tool lists. Two new agent HTTP
endpoints expose this:

- `GET /api/v0/mcp/tools` returns the cached tool list (supports
`?refresh=true`)
- `POST /api/v0/mcp/call-tool` proxies calls to the correct server

On each chat turn, chatd calls `ListMCPTools` through the existing
`AgentConn` tailnet connection, wraps each tool as a
`fantasy.AgentTool`, and adds them to the LLM's tool set alongside
built-in and admin-configured MCP tools. Tool names are prefixed with
the server name (`github__create_issue`) to avoid collisions.

Failed server connections are logged and skipped — they never block the
agent or break the chat. Child stdio processes are terminated on agent
shutdown.
This commit is contained in:
Kyle Carberry
2026-03-26 15:57:02 -04:00
committed by GitHub
parent 02b58534a0
commit 0f86c4237e
12 changed files with 1459 additions and 0 deletions
+17
View File
@@ -50,6 +50,7 @@ import (
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
"github.com/coder/coder/v2/agent/reconnectingpty"
"github.com/coder/coder/v2/agent/x/agentdesktop"
"github.com/coder/coder/v2/agent/x/agentmcp"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/gitauth"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -311,6 +312,8 @@ type agent struct {
gitAPI *agentgit.API
processAPI *agentproc.API
desktopAPI *agentdesktop.API
mcpManager *agentmcp.Manager
mcpAPI *agentmcp.API
socketServerEnabled bool
socketPath string
@@ -396,6 +399,8 @@ func (a *agent) init() {
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
@@ -1348,6 +1353,14 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
}
a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur)
a.scriptRunner.StartCron()
// Connect to workspace MCP servers after the
// lifecycle transition to avoid delaying Ready.
// This runs inside the tracked goroutine so it
// is properly awaited on shutdown.
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, manifest.Directory); mcpErr != nil {
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
}
})
if err != nil {
return xerrors.Errorf("track conn goroutine: %w", err)
@@ -2070,6 +2083,10 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
}
if err := a.mcpManager.Close(); err != nil {
a.logger.Error(a.hardCtx, "mcp manager close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
+1
View File
@@ -31,6 +31,7 @@ func (a *agent) apiHandler() http.Handler {
r.Mount("/api/v0/git", a.gitAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
r.Mount("/api/v0/mcp", a.mcpAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
+88
View File
@@ -0,0 +1,88 @@
package agentmcp
import (
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// API exposes MCP tool discovery and call proxying through the
// agent.
type API struct {
logger slog.Logger
manager *Manager
}
// NewAPI creates a new MCP API handler backed by the given
// manager.
func NewAPI(logger slog.Logger, manager *Manager) *API {
return &API{
logger: logger,
manager: manager,
}
}
// Routes returns the HTTP handler for MCP-related routes.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/tools", api.handleListTools)
r.Post("/call-tool", api.handleCallTool)
return r
}
// handleListTools returns the cached MCP tool definitions,
// optionally refreshing them first if ?refresh=true is set.
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Allow callers to force a tool re-scan before listing.
if r.URL.Query().Get("refresh") == "true" {
if err := api.manager.RefreshTools(ctx); err != nil {
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
}
}
tools := api.manager.Tools()
// Ensure non-nil so JSON serialization returns [] not null.
if tools == nil {
tools = []workspacesdk.MCPToolInfo{}
}
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListMCPToolsResponse{
Tools: tools,
})
}
// handleCallTool proxies a tool invocation to the appropriate
// MCP server based on the tool name prefix.
func (api *API) handleCallTool(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req workspacesdk.CallMCPToolRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
resp, err := api.manager.CallTool(ctx, req)
if err != nil {
status := http.StatusBadGateway
if errors.Is(err, ErrInvalidToolName) {
status = http.StatusBadRequest
} else if errors.Is(err, ErrUnknownServer) {
status = http.StatusNotFound
}
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: "MCP tool call failed.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
+115
View File
@@ -0,0 +1,115 @@
package agentmcp
import (
"encoding/json"
"os"
"slices"
"strings"
"golang.org/x/xerrors"
)
// ServerConfig describes a single MCP server parsed from a .mcp.json file.
type ServerConfig struct {
Name string `json:"name"`
Transport string `json:"type"`
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
}
// mcpConfigFile mirrors the on-disk .mcp.json schema.
type mcpConfigFile struct {
MCPServers map[string]json.RawMessage `json:"mcpServers"`
}
// mcpServerEntry is a single server block inside mcpServers.
type mcpServerEntry struct {
Command string `json:"command"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Type string `json:"type"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
}
// ParseConfig reads a .mcp.json file at path and returns the declared
// MCP servers sorted by name. It returns an empty slice when the
// mcpServers key is missing or empty.
func ParseConfig(path string) ([]ServerConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, xerrors.Errorf("read mcp config %q: %w", path, err)
}
var cfg mcpConfigFile
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, xerrors.Errorf("parse mcp config %q: %w", path, err)
}
if len(cfg.MCPServers) == 0 {
return []ServerConfig{}, nil
}
servers := make([]ServerConfig, 0, len(cfg.MCPServers))
for name, raw := range cfg.MCPServers {
var entry mcpServerEntry
if err := json.Unmarshal(raw, &entry); err != nil {
return nil, xerrors.Errorf("parse server %q in %q: %w", name, path, err)
}
if strings.Contains(name, ToolNameSep) || strings.HasPrefix(name, "_") || strings.HasSuffix(name, "_") {
return nil, xerrors.Errorf("server name %q in %q contains reserved separator %q or leading/trailing underscore", name, path, ToolNameSep)
}
transport := inferTransport(entry)
if transport == "" {
return nil, xerrors.Errorf("server %q in %q has no command or url", name, path)
}
resolveEnvVars(entry.Env)
servers = append(servers, ServerConfig{
Name: name,
Transport: transport,
Command: entry.Command,
Args: entry.Args,
Env: entry.Env,
URL: entry.URL,
Headers: entry.Headers,
})
}
slices.SortFunc(servers, func(a, b ServerConfig) int {
return strings.Compare(a.Name, b.Name)
})
return servers, nil
}
// inferTransport determines the transport type for a server entry.
// An explicit "type" field takes priority; otherwise the presence
// of "command" implies stdio and "url" implies http.
func inferTransport(e mcpServerEntry) string {
if e.Type != "" {
return e.Type
}
if e.Command != "" {
return "stdio"
}
if e.URL != "" {
return "http"
}
return ""
}
// resolveEnvVars expands ${VAR} references in env map values
// using the current process environment.
func resolveEnvVars(env map[string]string) {
for k, v := range env {
env[k] = os.Expand(v, os.Getenv)
}
}
+254
View File
@@ -0,0 +1,254 @@
package agentmcp_test
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/x/agentmcp"
)
func TestParseConfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
content string
expected []agentmcp.ServerConfig
expectError bool
}{
{
name: "StdioServer",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"my-server": map[string]any{
"command": "npx",
"args": []string{"-y", "@example/mcp-server"},
"env": map[string]string{"FOO": "bar"},
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "my-server",
Transport: "stdio",
Command: "npx",
Args: []string{"-y", "@example/mcp-server"},
Env: map[string]string{"FOO": "bar"},
},
},
},
{
name: "HTTPServer",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"remote": map[string]any{
"url": "https://example.com/mcp",
"headers": map[string]string{"Authorization": "Bearer tok"},
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "remote",
Transport: "http",
URL: "https://example.com/mcp",
Headers: map[string]string{"Authorization": "Bearer tok"},
},
},
},
{
name: "SSEServer",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"events": map[string]any{
"type": "sse",
"url": "https://example.com/sse",
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "events",
Transport: "sse",
URL: "https://example.com/sse",
},
},
},
{
name: "ExplicitTypeOverridesInference",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"hybrid": map[string]any{
"command": "some-binary",
"type": "http",
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "hybrid",
Transport: "http",
Command: "some-binary",
},
},
},
{
name: "EnvVarPassthrough",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"srv": map[string]any{
"command": "run",
"env": map[string]string{"PLAIN": "literal-value"},
},
},
}),
expected: []agentmcp.ServerConfig{
{
Name: "srv",
Transport: "stdio",
Command: "run",
Env: map[string]string{"PLAIN": "literal-value"},
},
},
},
{
name: "EmptyMCPServers",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{},
}),
expected: []agentmcp.ServerConfig{},
},
{
name: "MalformedJSON",
content: `{not valid json`,
expectError: true,
},
{
name: "ServerNameContainsSeparator",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"bad__name": map[string]any{"command": "run"},
},
}),
expectError: true,
},
{
name: "ServerNameTrailingUnderscore",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"server_": map[string]any{"command": "run"},
},
}),
expectError: true,
},
{
name: "ServerNameLeadingUnderscore",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"_server": map[string]any{"command": "run"},
},
}),
expectError: true,
},
{
name: "EmptyTransport", content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"empty": map[string]any{},
},
}),
expectError: true,
},
{
name: "MissingMCPServersKey",
content: mustJSON(t, map[string]any{
"servers": map[string]any{},
}),
expected: []agentmcp.ServerConfig{},
},
{
name: "MultipleServersSortedByName",
content: mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"zeta": map[string]any{"command": "z"},
"alpha": map[string]any{"command": "a"},
"mu": map[string]any{"command": "m"},
},
}),
expected: []agentmcp.ServerConfig{
{Name: "alpha", Transport: "stdio", Command: "a"},
{Name: "mu", Transport: "stdio", Command: "m"},
{Name: "zeta", Transport: "stdio", Command: "z"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, ".mcp.json")
err := os.WriteFile(path, []byte(tt.content), 0o600)
require.NoError(t, err)
got, err := agentmcp.ParseConfig(path)
if tt.expectError {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, got)
})
}
}
// TestParseConfig_EnvVarInterpolation verifies that ${VAR} references
// in env values are resolved from the process environment. This test
// cannot be parallel because t.Setenv is incompatible with t.Parallel.
func TestParseConfig_EnvVarInterpolation(t *testing.T) {
t.Setenv("TEST_MCP_TOKEN", "secret123")
content := mustJSON(t, map[string]any{
"mcpServers": map[string]any{
"srv": map[string]any{
"command": "run",
"env": map[string]string{"TOKEN": "${TEST_MCP_TOKEN}"},
},
},
})
dir := t.TempDir()
path := filepath.Join(dir, ".mcp.json")
err := os.WriteFile(path, []byte(content), 0o600)
require.NoError(t, err)
got, err := agentmcp.ParseConfig(path)
require.NoError(t, err)
require.Equal(t, []agentmcp.ServerConfig{
{
Name: "srv",
Transport: "stdio",
Command: "run",
Env: map[string]string{"TOKEN": "secret123"},
},
}, got)
}
func TestParseConfig_FileNotFound(t *testing.T) {
t.Parallel()
_, err := agentmcp.ParseConfig(filepath.Join(t.TempDir(), "nonexistent.json"))
require.Error(t, err)
}
// mustJSON marshals v to a JSON string, failing the test on error.
func mustJSON(t *testing.T, v any) string {
t.Helper()
data, err := json.Marshal(v)
require.NoError(t, err)
return string(data)
}
+447
View File
@@ -0,0 +1,447 @@
package agentmcp
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// ToolNameSep separates the server name from the original tool name
// in prefixed tool names. Double underscore avoids collisions with
// tool names that may contain single underscores.
const ToolNameSep = "__"
// connectTimeout bounds how long we wait for a single MCP server
// to start its transport and complete initialization.
const connectTimeout = 30 * time.Second
// toolCallTimeout bounds how long a single tool invocation may
// take before being canceled.
const toolCallTimeout = 60 * time.Second
var (
// ErrInvalidToolName is returned when the tool name format
// is not "server__tool".
ErrInvalidToolName = xerrors.New("invalid tool name format")
// ErrUnknownServer is returned when no MCP server matches
// the prefix in the tool name.
ErrUnknownServer = xerrors.New("unknown MCP server")
)
// Manager manages connections to MCP servers discovered from a
// workspace's .mcp.json file. It caches the aggregated tool list
// and proxies tool calls to the appropriate server.
type Manager struct {
mu sync.RWMutex
logger slog.Logger
closed bool
servers map[string]*serverEntry // keyed by server name
tools []workspacesdk.MCPToolInfo
}
// serverEntry pairs a server config with its connected client.
type serverEntry struct {
config ServerConfig
client *client.Client
}
// NewManager creates a new MCP client manager.
func NewManager(logger slog.Logger) *Manager {
return &Manager{
logger: logger,
servers: make(map[string]*serverEntry),
}
}
// Connect discovers .mcp.json in dir and connects to all
// configured servers. Failed servers are logged and skipped.
func (m *Manager) Connect(ctx context.Context, dir string) error {
path := filepath.Join(dir, ".mcp.json")
configs, err := ParseConfig(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return xerrors.Errorf("parse mcp config: %w", err)
}
// Connect to servers in parallel without holding the
// lock, since each connectServer call may block on
// network I/O for up to connectTimeout.
type connectedServer struct {
name string
config ServerConfig
client *client.Client
}
var (
mu sync.Mutex
connected []connectedServer
)
var eg errgroup.Group
for _, cfg := range configs {
eg.Go(func() error {
c, err := m.connectServer(ctx, cfg)
if err != nil {
m.logger.Warn(ctx, "skipping MCP server",
slog.F("server", cfg.Name),
slog.F("transport", cfg.Transport),
slog.Error(err),
)
return nil // Don't fail the group.
}
mu.Lock()
connected = append(connected, connectedServer{
name: cfg.Name, config: cfg, client: c,
})
mu.Unlock()
return nil
})
}
_ = eg.Wait()
m.mu.Lock()
if m.closed {
m.mu.Unlock()
// Close the freshly-connected clients since we're
// shutting down.
for _, cs := range connected {
_ = cs.client.Close()
}
return xerrors.New("manager closed")
}
// Close previous connections to avoid leaking child
// processes on agent reconnect.
for _, entry := range m.servers {
_ = entry.client.Close()
}
m.servers = make(map[string]*serverEntry, len(connected))
for _, cs := range connected {
m.servers[cs.name] = &serverEntry{
config: cs.config,
client: cs.client,
}
}
m.mu.Unlock()
// Refresh tools outside the lock to avoid blocking
// concurrent reads during network I/O.
if err := m.RefreshTools(ctx); err != nil {
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
}
return nil
}
// connectServer establishes a connection to a single MCP server
// and returns the connected client. It does not modify any Manager
// state.
func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
tr, err := createTransport(cfg)
if err != nil {
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
}
c := client.NewClient(tr)
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
if err := c.Start(connectCtx); err != nil {
_ = c.Close()
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
}
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "coder-agent",
Version: buildinfo.Version(),
},
},
})
if err != nil {
_ = c.Close()
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
}
return c, nil
}
// createTransport builds the mcp-go transport for a server config.
func createTransport(cfg ServerConfig) (transport.Interface, error) {
switch cfg.Transport {
case "stdio":
return transport.NewStdio(
cfg.Command,
buildEnv(cfg.Env),
cfg.Args...,
), nil
case "http", "":
return transport.NewStreamableHTTP(
cfg.URL,
transport.WithHTTPHeaders(cfg.Headers),
)
case "sse":
return transport.NewSSE(
cfg.URL,
transport.WithHeaders(cfg.Headers),
)
default:
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
}
}
// buildEnv merges the current process environment with explicit
// overrides, returning the result as KEY=VALUE strings suitable
// for the stdio transport.
func buildEnv(explicit map[string]string) []string {
env := os.Environ()
if len(explicit) == 0 {
return env
}
// Index existing env so explicit keys can override in-place.
existing := make(map[string]int, len(env))
for i, kv := range env {
if k, _, ok := strings.Cut(kv, "="); ok {
existing[k] = i
}
}
for k, v := range explicit {
entry := k + "=" + v
if idx, ok := existing[k]; ok {
env[idx] = entry
} else {
env = append(env, entry)
}
}
return env
}
// Tools returns the cached tool list. Thread-safe.
func (m *Manager) Tools() []workspacesdk.MCPToolInfo {
m.mu.RLock()
defer m.mu.RUnlock()
return slices.Clone(m.tools)
}
// CallTool proxies a tool call to the appropriate MCP server.
func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
serverName, originalName, err := splitToolName(req.ToolName)
if err != nil {
return workspacesdk.CallMCPToolResponse{}, err
}
m.mu.RLock()
entry, ok := m.servers[serverName]
m.mu.RUnlock()
if !ok {
return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("%w: %q", ErrUnknownServer, serverName)
}
callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
defer cancel()
result, err := entry.client.CallTool(callCtx, mcp.CallToolRequest{
Params: mcp.CallToolParams{
Name: originalName,
Arguments: req.Arguments,
},
})
if err != nil {
return workspacesdk.CallMCPToolResponse{}, xerrors.Errorf("call tool %q on %q: %w", originalName, serverName, err)
}
return convertResult(result), nil
}
// splitToolName extracts the server name and original tool name
// from a prefixed tool name like "server__tool".
func splitToolName(prefixed string) (serverName, toolName string, err error) {
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
if !ok || server == "" || tool == "" {
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
}
return server, tool, nil
}
// convertResult translates an MCP CallToolResult into a
// workspacesdk.CallMCPToolResponse. It iterates over content
// items and maps each recognized type.
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
if result == nil {
return workspacesdk.CallMCPToolResponse{}
}
var content []workspacesdk.MCPToolContent
for _, item := range result.Content {
switch c := item.(type) {
case mcp.TextContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: c.Text,
})
case mcp.ImageContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "image",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.AudioContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "audio",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.EmbeddedResource:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
})
case mcp.ResourceLink:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[resource link: %s]", c.URI),
})
default:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: fmt.Sprintf("[unsupported content type: %T]", item),
})
}
}
return workspacesdk.CallMCPToolResponse{
Content: content,
IsError: result.IsError,
}
}
// RefreshTools re-fetches tool lists from all connected servers
// in parallel and rebuilds the cache. On partial failure, tools
// from servers that responded successfully are merged with the
// existing cached tools for servers that failed, so a single
// dead server doesn't block updates from healthy ones.
func (m *Manager) RefreshTools(ctx context.Context) error {
// Snapshot servers under read lock.
m.mu.RLock()
servers := make(map[string]*serverEntry, len(m.servers))
for k, v := range m.servers {
servers[k] = v
}
m.mu.RUnlock()
// Fetch tool lists in parallel without holding any lock.
type serverTools struct {
name string
tools []workspacesdk.MCPToolInfo
}
var (
mu sync.Mutex
results []serverTools
failed []string
errs []error
)
var eg errgroup.Group
for name, entry := range servers {
eg.Go(func() error {
listCtx, cancel := context.WithTimeout(ctx, connectTimeout)
result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{})
cancel()
if err != nil {
m.logger.Warn(ctx, "failed to list tools from MCP server",
slog.F("server", name),
slog.Error(err),
)
mu.Lock()
errs = append(errs, xerrors.Errorf("list tools from %q: %w", name, err))
failed = append(failed, name)
mu.Unlock()
return nil
}
var tools []workspacesdk.MCPToolInfo
for _, tool := range result.Tools {
tools = append(tools, workspacesdk.MCPToolInfo{
ServerName: name,
Name: name + ToolNameSep + tool.Name,
Description: tool.Description,
Schema: tool.InputSchema.Properties,
Required: tool.InputSchema.Required,
})
}
mu.Lock()
results = append(results, serverTools{name: name, tools: tools})
mu.Unlock()
return nil
})
}
_ = eg.Wait()
// Build the new tool list. For servers that failed, preserve
// their tools from the existing cache so a single dead server
// doesn't remove healthy tools.
var merged []workspacesdk.MCPToolInfo
for _, st := range results {
merged = append(merged, st.tools...)
}
if len(failed) > 0 {
failedSet := make(map[string]struct{}, len(failed))
for _, f := range failed {
failedSet[f] = struct{}{}
}
m.mu.RLock()
for _, t := range m.tools {
if _, ok := failedSet[t.ServerName]; ok {
merged = append(merged, t)
}
}
m.mu.RUnlock()
}
slices.SortFunc(merged, func(a, b workspacesdk.MCPToolInfo) int {
return strings.Compare(a.Name, b.Name)
})
m.mu.Lock()
m.tools = merged
m.mu.Unlock()
return errors.Join(errs...)
}
// Close terminates all MCP server connections and child
// processes.
func (m *Manager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.closed = true
var errs []error
for _, entry := range m.servers {
errs = append(errs, entry.client.Close())
}
m.servers = make(map[string]*serverEntry)
m.tools = nil
return errors.Join(errs...)
}
+195
View File
@@ -0,0 +1,195 @@
package agentmcp
import (
"testing"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
func TestSplitToolName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantServer string
wantTool string
wantErr bool
}{
{
name: "Valid",
input: "server__tool",
wantServer: "server",
wantTool: "tool",
},
{
name: "ValidWithUnderscoresInTool",
input: "server__my_tool",
wantServer: "server",
wantTool: "my_tool",
},
{
name: "MissingSeparator",
input: "servertool",
wantErr: true,
},
{
name: "EmptyServer",
input: "__tool",
wantErr: true,
},
{
name: "EmptyTool",
input: "server__",
wantErr: true,
},
{
name: "JustSeparator",
input: "__",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server, tool, err := splitToolName(tt.input)
if tt.wantErr {
require.Error(t, err)
assert.ErrorIs(t, err, ErrInvalidToolName)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantServer, server)
assert.Equal(t, tt.wantTool, tool)
})
}
}
func TestConvertResult(t *testing.T) {
t.Parallel()
tests := []struct {
name string
// input is a pointer so we can test nil.
input *mcp.CallToolResult
want workspacesdk.CallMCPToolResponse
}{
{
name: "NilInput",
input: nil,
want: workspacesdk.CallMCPToolResponse{},
},
{
name: "TextContent",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "hello"},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "hello"},
},
},
},
{
name: "ImageContent",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.ImageContent{
Type: "image",
Data: "base64data",
MIMEType: "image/png",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "image", Data: "base64data", MediaType: "image/png"},
},
},
},
{
name: "AudioContent",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.AudioContent{
Type: "audio",
Data: "base64audio",
MIMEType: "audio/mp3",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "audio", Data: "base64audio", MediaType: "audio/mp3"},
},
},
},
{
name: "IsErrorPropagation",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "fail"},
},
IsError: true,
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "fail"},
},
IsError: true,
},
},
{
name: "MultipleContentItems",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "caption"},
mcp.ImageContent{
Type: "image",
Data: "imgdata",
MIMEType: "image/jpeg",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "caption"},
{Type: "image", Data: "imgdata", MediaType: "image/jpeg"},
},
},
},
{
name: "ResourceLink",
input: &mcp.CallToolResult{
Content: []mcp.Content{
mcp.ResourceLink{
Type: "resource_link",
URI: "file:///tmp/test.txt",
},
},
},
want: workspacesdk.CallMCPToolResponse{
Content: []workspacesdk.MCPToolContent{
{Type: "resource", Text: "[resource link: file:///tmp/test.txt]"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := convertResult(tt.input)
assert.Equal(t, tt.want, got)
})
}
}