diff --git a/agent/agent.go b/agent/agent.go index 16600387d8..e2698e3304 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -423,14 +423,14 @@ func (a *agent) init() { a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil, ) a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock) - a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp")) - a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager) + a.mcpManager = agentmcp.NewManager(a.gracefulCtx, a.logger.Named("mcp"), a.execer, a.updateCommandEnv) a.contextConfigAPI = agentcontextconfig.NewAPI(func() string { if m := a.manifest.Load(); m != nil { return m.Directory } return "" }, a.contextConfig) + a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager, a.contextConfigAPI.MCPConfigFiles) a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), a.sshServer, @@ -1413,8 +1413,8 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, // lifecycle transition to avoid delaying Ready. // This runs inside the tracked goroutine so it // is properly awaited on shutdown. - if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil { - a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr)) + if mcpErr := a.mcpManager.Reload(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil { + a.logger.Warn(ctx, "failed to reload workspace MCP servers", slog.Error(mcpErr)) } }) if err != nil { diff --git a/agent/x/agentmcp/api.go b/agent/x/agentmcp/api.go index 8582f68f00..9b632f8b9b 100644 --- a/agent/x/agentmcp/api.go +++ b/agent/x/agentmcp/api.go @@ -1,6 +1,7 @@ package agentmcp import ( + "context" "errors" "net/http" @@ -15,16 +16,24 @@ import ( // API exposes MCP tool discovery and call proxying through the // agent. 
type API struct { - logger slog.Logger - manager *Manager + logger slog.Logger + manager *Manager + mcpConfigFiles func() []string } // NewAPI creates a new MCP API handler backed by the given -// manager. -func NewAPI(logger slog.Logger, manager *Manager) *API { +// manager. The mcpConfigFiles callback returns the current +// resolved config file paths; it is called on every tool-list +// request to detect config changes. +func NewAPI( + logger slog.Logger, + manager *Manager, + mcpConfigFiles func() []string, +) *API { return &API{ - logger: logger, - manager: manager, + logger: logger, + manager: manager, + mcpConfigFiles: mcpConfigFiles, } } @@ -36,13 +45,38 @@ func (api *API) Routes() http.Handler { return r } -// handleListTools returns the cached MCP tool definitions, -// optionally refreshing them first if ?refresh=true is set. +// handleListTools checks whether any .mcp.json config file +// has changed since the last reload, triggering a differential +// reload if so, then returns the cached MCP tool definitions. +// The ?refresh=true query parameter forces a tool re-scan +// independent of config changes. func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + // Check config freshness and reload if changed. + var reloaded bool + paths := api.mcpConfigFiles() + if api.manager.SnapshotChanged(paths) { + if err := api.manager.Reload(ctx, paths); err != nil { + // Categorize the error for operator debugging. + switch { + case errors.Is(err, context.Canceled): + api.logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err)) + case errors.Is(err, context.DeadlineExceeded): + api.logger.Warn(ctx, "mcp reload timed out", slog.Error(err)) + default: + api.logger.Warn(ctx, "mcp reload failed", slog.Error(err)) + } + // Fall through to return whatever tools we have. + } else { + reloaded = true + } + } + // Allow callers to force a tool re-scan before listing. 
- if r.URL.Query().Get("refresh") == "true" { + // Skip if a config reload ran above, since it already + // refreshes tools as part of the reload. + if r.URL.Query().Get("refresh") == "true" && !reloaded { if err := api.manager.RefreshTools(ctx); err != nil { api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err)) } diff --git a/agent/x/agentmcp/api_internal_test.go b/agent/x/agentmcp/api_internal_test.go new file mode 100644 index 0000000000..a2135204ef --- /dev/null +++ b/agent/x/agentmcp/api_internal_test.go @@ -0,0 +1,228 @@ +package agentmcp + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/agent/agentexec" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/testutil" +) + +func TestHandleListTools_ReloadOnChange(t *testing.T) { + t.Parallel() + + if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" { + runFakeMCPServer() + return + } + + // Cases that share the single-request-and-check pattern. 
+ type singleRequestCase struct { + name string + entries func(t *testing.T) map[string]mcpServerEntry + reloadManager bool + closeManager bool + expectedTools int + toolNameContains string + } + + cases := []singleRequestCase{ + { + name: "InitialRequestNoReload", + entries: func(t *testing.T) map[string]mcpServerEntry { + t.Helper() + _, entry := fakeMCPServerConfig(t, "srv") + return map[string]mcpServerEntry{"srv": entry} + }, + reloadManager: true, + expectedTools: 1, + toolNameContains: "echo", + }, + { + name: "ManagerClosedReturnsEmpty", + entries: func(_ *testing.T) map[string]mcpServerEntry { + return map[string]mcpServerEntry{} + }, + closeManager: true, + expectedTools: 0, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + configPath := writeMCPConfig(t, dir, tc.entries(t)) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + if tc.closeManager { + require.NoError(t, m.Close()) + } else { + t.Cleanup(func() { _ = m.Close() }) + } + + if tc.reloadManager { + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + } + + api := NewAPI(logger, m, func() []string { + return []string{configPath} + }) + + req := httptest.NewRequest(http.MethodGet, "/tools", nil) + rec := httptest.NewRecorder() + api.Routes().ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp workspacesdk.ListMCPToolsResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + require.Len(t, resp.Tools, tc.expectedTools) + if tc.toolNameContains != "" { + assert.Contains(t, resp.Tools[0].Name, tc.toolNameContains) + } + }) + } + + // ConfigChangeTriggersReload has a mutate-then-re-request flow + // that does not fit the single-request table pattern. 
+ t.Run("ConfigChangeTriggersReload", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry1 := fakeMCPServerConfig(t, "srv1") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + api := NewAPI(logger, m, func() []string { + return []string{configPath} + }) + + // Verify initial tools. + req := httptest.NewRequest(http.MethodGet, "/tools", nil) + rec := httptest.NewRecorder() + api.Routes().ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp1 workspacesdk.ListMCPToolsResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp1)) + require.Len(t, resp1.Tools, 1) + assert.Contains(t, resp1.Tools[0].Name, "srv1") + + // Mutate the config file. + _, entry2 := fakeMCPServerConfig(t, "srv2") + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2}) + + // Next request should trigger a reload and return new tools. + req2 := httptest.NewRequest(http.MethodGet, "/tools", nil) + rec2 := httptest.NewRecorder() + api.Routes().ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + + var resp2 workspacesdk.ListMCPToolsResponse + require.NoError(t, json.NewDecoder(rec2.Body).Decode(&resp2)) + require.Len(t, resp2.Tools, 1) + assert.Contains(t, resp2.Tools[0].Name, "srv2") + }) +} + +func TestHandleListTools_RefreshParam(t *testing.T) { + t.Parallel() + + if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" { + runFakeMCPServer() + return + } + + t.Run("RefreshTrueUnchangedSnapshot", func(t *testing.T) { + // Exercises the ?refresh=true code path when the config + // snapshot is unchanged. Verifies the endpoint returns + // tools without error. 
+ t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + api := NewAPI(logger, m, func() []string { + return []string{configPath} + }) + + req := httptest.NewRequest(http.MethodGet, "/tools?refresh=true", nil) + rec := httptest.NewRecorder() + api.Routes().ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp workspacesdk.ListMCPToolsResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + // Tool should still be present after refresh. + require.Len(t, resp.Tools, 1) + assert.Contains(t, resp.Tools[0].Name, "echo") + }) + + t.Run("RefreshTrueWithChangedConfig", func(t *testing.T) { + // Exercises the ?refresh=true code path when the config + // has also changed. The reload path already calls + // RefreshTools, so the handler skips the redundant call. + // This test covers the branch; it cannot observe the + // skip without a mock. + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry1 := fakeMCPServerConfig(t, "srv1") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + api := NewAPI(logger, m, func() []string { + return []string{configPath} + }) + + // Mutate config. 
+ _, entry2 := fakeMCPServerConfig(t, "srv2") + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2}) + + req := httptest.NewRequest(http.MethodGet, "/tools?refresh=true", nil) + rec := httptest.NewRecorder() + api.Routes().ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp workspacesdk.ListMCPToolsResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + require.Len(t, resp.Tools, 1) + assert.Contains(t, resp.Tools[0].Name, "srv2") + }) +} diff --git a/agent/x/agentmcp/manager.go b/agent/x/agentmcp/manager.go index 66662cd3d3..94fc1bf0e3 100644 --- a/agent/x/agentmcp/manager.go +++ b/agent/x/agentmcp/manager.go @@ -5,7 +5,10 @@ import ( "errors" "fmt" "io/fs" + "maps" "os" + "os/exec" + "reflect" "slices" "strings" "sync" @@ -16,8 +19,11 @@ import ( "github.com/mark3labs/mcp-go/mcp" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" + tailscalesingleflight "tailscale.com/util/singleflight" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentexec" + "github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk/workspacesdk" ) @@ -44,15 +50,30 @@ var ( ErrUnknownServer = xerrors.New("unknown MCP server") ) +// fileSnapshot records the identity of a config file at the time +// it was last read. +type fileSnapshot struct { + exists bool + modTime time.Time + size int64 +} + // Manager manages connections to MCP servers discovered from a // workspace's .mcp.json file. It caches the aggregated tool list // and proxies tool calls to the appropriate server. 
type Manager struct { - mu sync.RWMutex - logger slog.Logger - closed bool - servers map[string]*serverEntry // keyed by server name - tools []workspacesdk.MCPToolInfo + ctx context.Context + execer agentexec.Execer + updateEnv func(current []string) ([]string, error) + + mu sync.RWMutex + logger slog.Logger + closed bool + servers map[string]*serverEntry + tools []workspacesdk.MCPToolInfo + snapshot map[string]fileSnapshot + serverGen uint64 + sf tailscalesingleflight.Group[string, struct{}] } // serverEntry pairs a server config with its connected client. @@ -61,18 +82,189 @@ type serverEntry struct { client *client.Client } -// NewManager creates a new MCP client manager. -func NewManager(logger slog.Logger) *Manager { +// NewManager creates a new MCP client manager. The ctx bounds +// subprocess lifetime. The execer applies resource limits to +// MCP server subprocesses. The updateEnv callback enriches the +// subprocess environment to match interactive sessions. +func NewManager( + ctx context.Context, + logger slog.Logger, + execer agentexec.Execer, + updateEnv func([]string) ([]string, error), +) *Manager { return &Manager{ - logger: logger, - servers: make(map[string]*serverEntry), + ctx: ctx, + logger: logger, + execer: execer, + updateEnv: updateEnv, + servers: make(map[string]*serverEntry), + snapshot: make(map[string]fileSnapshot), } } -// Connect reads MCP config files at the given absolute paths and -// connects to all configured servers. Failed servers are logged -// and skipped. Missing config files are silently skipped. -func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error { +// Reload checks whether config files have changed and, if so, +// performs a differential reconnect. Concurrent callers are +// coalesced via singleflight; the reload body runs under the +// Manager's lifetime context so it survives caller cancellation. 
+func (m *Manager) Reload(ctx context.Context, paths []string) error { + m.mu.RLock() + closed := m.closed + hasSnapshot := len(m.snapshot) > 0 + m.mu.RUnlock() + if closed { + return xerrors.New("manager closed") + } + + // Double-check: another goroutine may have completed a + // reload between the caller's SnapshotChanged and this + // call. The singleflight body uses its own resolved paths. + if hasSnapshot && !m.SnapshotChanged(paths) { + return nil + } + + // All concurrent callers share one in-flight reload keyed + // by "reload". If a concurrent caller resolves different paths + // (e.g. after a manifest reconnect), its paths are not + // consulted; the next SnapshotChanged check after this + // reload completes will detect the mismatch and trigger + // a fresh reload. + ch := m.sf.DoChan("reload", func() (struct{}, error) { + err := m.doReload(m.ctx, paths) + return struct{}{}, err + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case res := <-ch: + return res.Err + } +} + +// SnapshotChanged checks whether any config file has changed +// since the last reload by comparing os.Stat results against +// the stored snapshot. +func (m *Manager) SnapshotChanged(paths []string) bool { + seen := make(map[string]struct{}, len(paths)) + unique := make([]string, 0, len(paths)) + for _, p := range paths { + if _, ok := seen[p]; !ok { + seen[p] = struct{}{} + unique = append(unique, p) + } + } + paths = unique + + m.mu.RLock() + snap := maps.Clone(m.snapshot) + snapshotLen := len(snap) + m.mu.RUnlock() + + if len(paths) != snapshotLen { + return true + } + + for _, p := range paths { + prev, ok := snap[p] + if !ok { + return true + } + + info, err := os.Stat(p) + if err != nil { + // Stat failed; changed only if the file existed before. + if prev.exists { + return true + } + continue + } + + // Stat succeeded but file was absent before: it appeared.
+ if !prev.exists { + return true + } + + if !info.ModTime().Equal(prev.modTime) || info.Size() != prev.size { + return true + } + } + + return false +} + +// serverDiff is the output of classifyServers: which servers to +// connect, which to close, which to keep, and a snapshot of the +// previous map for fallback on connect failure. +type serverDiff struct { + toConnect []ServerConfig + toClose []*serverEntry + keep map[string]*serverEntry + prev map[string]*serverEntry +} + +type connectedServer struct { + name string + config ServerConfig + client *client.Client +} + +// doReload reads MCP config files and performs a differential +// reconnect. Unchanged servers keep their existing client; new or +// changed servers get a fresh connection; removed servers are +// closed. +func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error { + allConfigs, snap := m.parseAndDedup(ctx, mcpConfigFiles) + + wanted := make(map[string]ServerConfig, len(allConfigs)) + for _, cfg := range allConfigs { + wanted[cfg.Name] = cfg + } + + diff, err := m.classifyServers(wanted) + if err != nil { + return err + } + + connected := m.connectAll(ctx, diff.toConnect) + + replaced, err := m.installServers(wanted, diff, connected, snap) + if err != nil { + return err + } + + // Close removed and replaced servers outside the lock to + // avoid leaking child processes and to avoid blocking + // concurrent readers on subprocess I/O. + // Note: a concurrent CallTool that captured a removed + // entry's client before the swap may call a closed client. + // This is a narrow race that self-heals on the next request. + for _, entry := range diff.toClose { + _ = entry.client.Close() + } + for _, entry := range replaced { + _ = entry.client.Close() + } + + // Refresh tools outside the lock to avoid blocking + // concurrent reads during network I/O. 
+ if err := m.RefreshTools(ctx); err != nil { + m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) + } + return nil +} + +// parseAndDedup reads all config files and returns a deduplicated +// list of server configs. Missing files are silently skipped; +// parse errors are logged and skipped. +func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) { + // Stat before reading so the snapshot is conservatively old. + // If a file changes between stat and read, the snapshot + // records the old mtime, SnapshotChanged detects a mismatch + // on the next check, and triggers a re-read. False positives + // (extra reload) are safe; false negatives (missed change) + // are not. + snap := captureSnapshot(mcpConfigFiles) + var allConfigs []ServerConfig for _, configPath := range mcpConfigFiles { configs, err := ParseConfig(configPath) @@ -99,26 +291,55 @@ func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error { seen[cfg.Name] = struct{}{} deduped = append(deduped, cfg) } - allConfigs = deduped + return deduped, snap +} - if len(allConfigs) == 0 { - return nil +// classifyServers compares wanted configs against the current +// server map and returns a diff describing what changed. +// Acquires and releases m.mu for reading. +func (m *Manager) classifyServers(wanted map[string]ServerConfig) (*serverDiff, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.closed { + return nil, xerrors.New("manager closed") } - // Connect to servers in parallel without holding the - // lock, since each connectServer call may block on - // network I/O for up to connectTimeout. 
- type connectedServer struct { - name string - config ServerConfig - client *client.Client + diff := &serverDiff{ + keep: make(map[string]*serverEntry), } + + for name, wantCfg := range wanted { + if existing, ok := m.servers[name]; ok { + if reflect.DeepEqual(existing.config, wantCfg) { + diff.keep[name] = existing + } else { + diff.toConnect = append(diff.toConnect, wantCfg) + } + } else { + diff.toConnect = append(diff.toConnect, wantCfg) + } + } + + for name, entry := range m.servers { + if _, ok := wanted[name]; !ok { + diff.toClose = append(diff.toClose, entry) + } + } + + diff.prev = maps.Clone(m.servers) + return diff, nil +} + +// connectAll runs connectServer in parallel for the given configs. +// Failed connects are logged and skipped. +func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer { var ( mu sync.Mutex connected []connectedServer ) var eg errgroup.Group - for _, cfg := range allConfigs { + for _, cfg := range toConnect { eg.Go(func() error { c, err := m.connectServer(ctx, cfg) if err != nil { @@ -138,131 +359,81 @@ func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error { }) } _ = eg.Wait() + return connected +} +// installServers builds the new server map from diff.keep and the +// connected list, falling back to diff.prev when a connect failed. +// Returns old entries replaced by successful connects (caller +// closes them). Acquires and releases m.mu. +func (m *Manager) installServers( + wanted map[string]ServerConfig, + diff *serverDiff, + connected []connectedServer, + snap map[string]fileSnapshot, +) ([]*serverEntry, error) { m.mu.Lock() + defer m.mu.Unlock() + if m.closed { - m.mu.Unlock() - // Close the freshly-connected clients since we're - // shutting down. 
for _, cs := range connected { _ = cs.client.Close() } - return xerrors.New("manager closed") + return nil, xerrors.New("manager closed") } - // Close previous connections to avoid leaking child - // processes on agent reconnect. - for _, entry := range m.servers { - _ = entry.client.Close() - } - m.servers = make(map[string]*serverEntry, len(connected)) - + newConnected := make(map[string]connectedServer, len(connected)) for _, cs := range connected { - m.servers[cs.name] = &serverEntry{ - config: cs.config, - client: cs.client, + newConnected[cs.name] = cs + } + + newServers := make(map[string]*serverEntry, len(wanted)) + for name, entry := range diff.keep { + newServers[name] = entry + } + + var replaced []*serverEntry + for name, wantCfg := range wanted { + if _, kept := diff.keep[name]; kept { + continue } - } - m.mu.Unlock() - - // Refresh tools outside the lock to avoid blocking - // concurrent reads during network I/O. - if err := m.RefreshTools(ctx); err != nil { - m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) - } - return nil -} - -// connectServer establishes a connection to a single MCP server -// and returns the connected client. It does not modify any Manager -// state. -func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) { - tr, err := createTransport(cfg) - if err != nil { - return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err) - } - - c := client.NewClient(tr) - - connectCtx, cancel := context.WithTimeout(ctx, connectTimeout) - defer cancel() - - // Use the parent ctx (not connectCtx) so the subprocess outlives - // the connect/initialize handshake. connectCtx bounds only the - // Initialize call below. The subprocess is cleaned up when the - // Manager is closed or ctx is canceled. 
- if err := c.Start(ctx); err != nil { - _ = c.Close() - return nil, xerrors.Errorf("start %q: %w", cfg.Name, err) - } - - _, err = c.Initialize(connectCtx, mcp.InitializeRequest{ - Params: mcp.InitializeParams{ - ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - ClientInfo: mcp.Implementation{ - Name: "coder-agent", - Version: buildinfo.Version(), - }, - }, - }) - if err != nil { - _ = c.Close() - return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err) - } - - return c, nil -} - -// createTransport builds the mcp-go transport for a server config. -func createTransport(cfg ServerConfig) (transport.Interface, error) { - switch cfg.Transport { - case "stdio": - return transport.NewStdio( - cfg.Command, - buildEnv(cfg.Env), - cfg.Args..., - ), nil - case "http", "": - return transport.NewStreamableHTTP( - cfg.URL, - transport.WithHTTPHeaders(cfg.Headers), - ) - case "sse": - return transport.NewSSE( - cfg.URL, - transport.WithHeaders(cfg.Headers), - ) - default: - return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport) - } -} - -// buildEnv merges the current process environment with explicit -// overrides, returning the result as KEY=VALUE strings suitable -// for the stdio transport. -func buildEnv(explicit map[string]string) []string { - env := os.Environ() - if len(explicit) == 0 { - return env - } - - // Index existing env so explicit keys can override in-place. - existing := make(map[string]int, len(env)) - for i, kv := range env { - if k, _, ok := strings.Cut(kv, "="); ok { - existing[k] = i + if cs, ok := newConnected[wantCfg.Name]; ok { + newServers[wantCfg.Name] = &serverEntry{ + config: cs.config, + client: cs.client, + } + if prev, existed := diff.prev[wantCfg.Name]; existed { + replaced = append(replaced, prev) + } + } else if prev, existed := diff.prev[wantCfg.Name]; existed { + // Connect failed; retain the old client. 
+ newServers[wantCfg.Name] = prev } } - for k, v := range explicit { - entry := k + "=" + v - if idx, ok := existing[k]; ok { - env[idx] = entry - } else { - env = append(env, entry) + m.servers = newServers + m.serverGen++ + m.snapshot = snap + return replaced, nil +} + +// captureSnapshot stats each path and returns the current +// snapshot map. +func captureSnapshot(paths []string) map[string]fileSnapshot { + snap := make(map[string]fileSnapshot, len(paths)) + for _, p := range paths { + info, err := os.Stat(p) + if err != nil { + snap[p] = fileSnapshot{exists: false} + continue + } + snap[p] = fileSnapshot{ + exists: true, + modTime: info.ModTime(), + size: info.Size(), } } - return env + return snap } // Tools returns the cached tool list. Thread-safe. @@ -304,68 +475,6 @@ func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequ return convertResult(result), nil } -// splitToolName extracts the server name and original tool name -// from a prefixed tool name like "server__tool". -func splitToolName(prefixed string) (serverName, toolName string, err error) { - server, tool, ok := strings.Cut(prefixed, ToolNameSep) - if !ok || server == "" || tool == "" { - return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed) - } - return server, tool, nil -} - -// convertResult translates an MCP CallToolResult into a -// workspacesdk.CallMCPToolResponse. It iterates over content -// items and maps each recognized type. 
-func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse { - if result == nil { - return workspacesdk.CallMCPToolResponse{} - } - - var content []workspacesdk.MCPToolContent - for _, item := range result.Content { - switch c := item.(type) { - case mcp.TextContent: - content = append(content, workspacesdk.MCPToolContent{ - Type: "text", - Text: c.Text, - }) - case mcp.ImageContent: - content = append(content, workspacesdk.MCPToolContent{ - Type: "image", - Data: c.Data, - MediaType: c.MIMEType, - }) - case mcp.AudioContent: - content = append(content, workspacesdk.MCPToolContent{ - Type: "audio", - Data: c.Data, - MediaType: c.MIMEType, - }) - case mcp.EmbeddedResource: - content = append(content, workspacesdk.MCPToolContent{ - Type: "resource", - Text: fmt.Sprintf("[embedded resource: %T]", c.Resource), - }) - case mcp.ResourceLink: - content = append(content, workspacesdk.MCPToolContent{ - Type: "resource", - Text: fmt.Sprintf("[resource link: %s]", c.URI), - }) - default: - content = append(content, workspacesdk.MCPToolContent{ - Type: "text", - Text: fmt.Sprintf("[unsupported content type: %T]", item), - }) - } - } - - return workspacesdk.CallMCPToolResponse{ - Content: content, - IsError: result.IsError, - } -} - // RefreshTools re-fetches tool lists from all connected servers // in parallel and rebuilds the cache. On partial failure, tools // from servers that responded successfully are merged with the @@ -378,6 +487,7 @@ func (m *Manager) RefreshTools(ctx context.Context) error { for k, v := range m.servers { servers[k] = v } + gen := m.serverGen m.mu.RUnlock() // Fetch tool lists in parallel without holding any lock. @@ -451,7 +561,12 @@ func (m *Manager) RefreshTools(ctx context.Context) error { }) m.mu.Lock() - m.tools = merged + // Skip the write if the server map changed since the + // snapshot. A doReload that bumped the generation will + // produce a correct tool list; this write would be stale. 
+ if m.serverGen == gen { + m.tools = merged + } m.mu.Unlock() return errors.Join(errs...) @@ -466,9 +581,187 @@ func (m *Manager) Close() error { m.closed = true var errs []error for _, entry := range m.servers { - errs = append(errs, entry.client.Close()) + if err := entry.client.Close(); err != nil { + // Subprocess kill signals are expected during shutdown. + // The stdio transport returns cmd.Wait() which surfaces + // "signal: killed" as an exec.ExitError. + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + errs = append(errs, err) + } + } } m.servers = make(map[string]*serverEntry) m.tools = nil return errors.Join(errs...) } + +// connectServer establishes a connection to a single MCP server +// and returns the connected client. It does not modify any Manager +// state. +func (m *Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) { + tr, err := m.createTransport(ctx, cfg) + if err != nil { + return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err) + } + + c := client.NewClient(tr) + + connectCtx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + + // Use the parent ctx (not connectCtx) so the subprocess outlives + // the connect/initialize handshake. connectCtx bounds only the + // Initialize call below. The subprocess is cleaned up when the + // Manager is closed or ctx is canceled. + if err := c.Start(ctx); err != nil { + _ = c.Close() + return nil, xerrors.Errorf("start %q: %w", cfg.Name, err) + } + + _, err = c.Initialize(connectCtx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "coder-agent", + Version: buildinfo.Version(), + }, + }, + }) + if err != nil { + _ = c.Close() + return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err) + } + + return c, nil +} + +// createTransport builds the mcp-go transport for a server config. 
+func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transport.Interface, error) { + switch cfg.Transport { + case "stdio": + env := m.buildEnv(ctx, cfg.Env) + return transport.NewStdioWithOptions( + cfg.Command, + env, + cfg.Args, + transport.WithCommandFunc(func(ctx context.Context, command string, cmdEnv []string, args []string) (*exec.Cmd, error) { + cmd := m.execer.CommandContext(ctx, command, args...) + cmd.Env = cmdEnv + return cmd, nil + }), + ), nil + case "http", "": + return transport.NewStreamableHTTP( + cfg.URL, + transport.WithHTTPHeaders(cfg.Headers), + ) + case "sse": + return transport.NewSSE( + cfg.URL, + transport.WithHeaders(cfg.Headers), + ) + default: + return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport) + } +} + +// buildEnv enriches the process environment via the agent's +// updateEnv callback, then merges explicit overrides from the +// server config on top. +func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string { + env := usershell.SystemEnvInfo{}.Environ() + if m.updateEnv != nil { + var err error + env, err = m.updateEnv(env) + if err != nil { + m.logger.Warn(ctx, "failed to enrich MCP server environment", + slog.Error(err), + ) + env = usershell.SystemEnvInfo{}.Environ() + } + } + if len(explicit) == 0 { + return env + } + + // Index existing env so explicit keys can override in-place. + existing := make(map[string]int, len(env)) + for i, kv := range env { + if k, _, ok := strings.Cut(kv, "="); ok { + existing[k] = i + } + } + + for k, v := range explicit { + entry := k + "=" + v + if idx, ok := existing[k]; ok { + env[idx] = entry + } else { + env = append(env, entry) + } + } + return env +} + +// splitToolName extracts the server name and original tool name +// from a prefixed tool name like "server__tool". 
+func splitToolName(prefixed string) (serverName, toolName string, err error) { + server, tool, ok := strings.Cut(prefixed, ToolNameSep) + if !ok || server == "" || tool == "" { + return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed) + } + return server, tool, nil +} + +// convertResult translates an MCP CallToolResult into a +// workspacesdk.CallMCPToolResponse. It iterates over content +// items and maps each recognized type. +func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse { + if result == nil { + return workspacesdk.CallMCPToolResponse{} + } + + var content []workspacesdk.MCPToolContent + for _, item := range result.Content { + switch c := item.(type) { + case mcp.TextContent: + content = append(content, workspacesdk.MCPToolContent{ + Type: "text", + Text: c.Text, + }) + case mcp.ImageContent: + content = append(content, workspacesdk.MCPToolContent{ + Type: "image", + Data: c.Data, + MediaType: c.MIMEType, + }) + case mcp.AudioContent: + content = append(content, workspacesdk.MCPToolContent{ + Type: "audio", + Data: c.Data, + MediaType: c.MIMEType, + }) + case mcp.EmbeddedResource: + content = append(content, workspacesdk.MCPToolContent{ + Type: "resource", + Text: fmt.Sprintf("[embedded resource: %T]", c.Resource), + }) + case mcp.ResourceLink: + content = append(content, workspacesdk.MCPToolContent{ + Type: "resource", + Text: fmt.Sprintf("[resource link: %s]", c.URI), + }) + default: + content = append(content, workspacesdk.MCPToolContent{ + Type: "text", + Text: fmt.Sprintf("[unsupported content type: %T]", item), + }) + } + } + + return workspacesdk.CallMCPToolResponse{ + Content: content, + IsError: result.IsError, + } +} diff --git a/agent/x/agentmcp/manager_internal_test.go b/agent/x/agentmcp/manager_internal_test.go index 9510a3affe..7dbfb00a63 100644 --- a/agent/x/agentmcp/manager_internal_test.go +++ b/agent/x/agentmcp/manager_internal_test.go @@ -12,6 
+12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/testutil" ) @@ -227,7 +228,7 @@ func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) { } ctx := testutil.Context(t, testutil.WaitLong) - m := &Manager{} + m := &Manager{execer: agentexec.DefaultExecer} client, err := m.connectServer(ctx, cfg) require.NoError(t, err, "connectServer should succeed") t.Cleanup(func() { _ = client.Close() }) diff --git a/agent/x/agentmcp/reload_internal_test.go b/agent/x/agentmcp/reload_internal_test.go new file mode 100644 index 0000000000..0f9c903323 --- /dev/null +++ b/agent/x/agentmcp/reload_internal_test.go @@ -0,0 +1,708 @@ +package agentmcp + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/agent/agentexec" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/testutil" +) + +// writeMCPConfig writes a .mcp.json file with the given server +// entries. Each entry maps a server name to its config. +func writeMCPConfig(t *testing.T, dir string, servers map[string]mcpServerEntry) string { + t.Helper() + path := filepath.Join(dir, ".mcp.json") + cfg := mcpConfigFile{MCPServers: make(map[string]json.RawMessage)} + for name, entry := range servers { + raw, err := json.Marshal(entry) + require.NoError(t, err) + cfg.MCPServers[name] = raw + } + data, err := json.Marshal(cfg) + require.NoError(t, err) + err = os.WriteFile(path, data, 0o600) + require.NoError(t, err) + return path +} + +// fakeMCPServerConfig returns a ServerConfig that launches a fake +// MCP server using the test binary re-exec pattern. 
+func fakeMCPServerConfig(t *testing.T, name string) (ServerConfig, mcpServerEntry) { + t.Helper() + testBin, err := os.Executable() + require.NoError(t, err) + cfg := ServerConfig{ + Name: name, + Transport: "stdio", + Command: testBin, + Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"}, + Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"}, + } + entry := mcpServerEntry{ + Command: testBin, + Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"}, + Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"}, + } + return cfg, entry +} + +func TestSnapshotChanged(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + setup func(t *testing.T, dir string) []string + mutate func(t *testing.T, dir string) + checkPaths func(t *testing.T, dir string, initialPaths []string) []string + want bool + } + + cases := []testCase{ + { + name: "UnchangedFiles", + setup: func(t *testing.T, dir string) []string { + t.Helper() + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + return []string{configPath} + }, + want: false, + }, + { + name: "ContentChange", + setup: func(t *testing.T, dir string) []string { + t.Helper() + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + return []string{configPath} + }, + mutate: func(t *testing.T, dir string) { + t.Helper() + _, entry2 := fakeMCPServerConfig(t, "srv2") + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2}) + }, + want: true, + }, + { + name: "FileBecomesMissing", + setup: func(t *testing.T, dir string) []string { + t.Helper() + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + return []string{configPath} + }, + mutate: func(t *testing.T, dir string) { + t.Helper() + require.NoError(t, os.Remove(filepath.Join(dir, ".mcp.json"))) + }, + want: 
true, + }, + { + name: "FileAppears", + setup: func(t *testing.T, dir string) []string { + t.Helper() + return []string{filepath.Join(dir, ".mcp.json")} + }, + mutate: func(t *testing.T, dir string) { + t.Helper() + _, entry := fakeMCPServerConfig(t, "srv") + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + }, + want: true, + }, + { + name: "BothAbsentUnchanged", + setup: func(t *testing.T, dir string) []string { + t.Helper() + return []string{filepath.Join(dir, ".mcp.json")} + }, + want: false, + }, + { + name: "PathSetDiffers", + setup: func(t *testing.T, dir string) []string { + t.Helper() + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + return []string{configPath} + }, + checkPaths: func(t *testing.T, dir string, initialPaths []string) []string { + t.Helper() + extraPath := filepath.Join(dir, "extra.mcp.json") + return append(initialPaths, extraPath) + }, + want: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + paths := tc.setup(t, dir) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, paths) + require.NoError(t, err) + + if tc.mutate != nil { + tc.mutate(t, dir) + } + + checkPaths := paths + if tc.checkPaths != nil { + checkPaths = tc.checkPaths(t, dir, paths) + } + + changed := m.SnapshotChanged(checkPaths) + assert.Equal(t, tc.want, changed) + }) + } +} + +func TestSnapshotChanged_MultipleConfigFiles(t *testing.T) { + t.Parallel() + + if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" { + runFakeMCPServer() + return + } + + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + dir1 := t.TempDir() + dir2 := t.TempDir() + + _, entry1 := fakeMCPServerConfig(t, "srv1") 
+ _, entry2 := fakeMCPServerConfig(t, "srv2") + path1 := writeMCPConfig(t, dir1, map[string]mcpServerEntry{"srv1": entry1}) + path2 := writeMCPConfig(t, dir2, map[string]mcpServerEntry{"srv2": entry2}) + paths := []string{path1, path2} + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + // Initial reload with both config files. + err := m.Reload(ctx, paths) + require.NoError(t, err) + + // Both files unchanged. + assert.False(t, m.SnapshotChanged(paths), + "snapshot should not change when both files are unchanged") + + // Mutate only the second file. + _, entry2b := fakeMCPServerConfig(t, "srv2b") + writeMCPConfig(t, dir2, map[string]mcpServerEntry{"srv2b": entry2b}) + + assert.True(t, m.SnapshotChanged(paths), + "snapshot should change when second file is mutated") + + // Reload picks up the mutation. + err = m.Reload(ctx, paths) + require.NoError(t, err) + + // Tools from both files should be present. + tools := m.Tools() + require.Len(t, tools, 2, "should have tools from both config files") + assert.Contains(t, tools[0].Name, "srv1", + "first tool should be from first config") + assert.Contains(t, tools[1].Name, "srv2b", + "second tool should be from second config") +} + +func TestReload(t *testing.T) { + t.Parallel() + + t.Run("SingleReloadUpdatesSnapshot", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + tools := m.Tools() + require.Len(t, tools, 1, "should have one tool from the fake server") + assert.Contains(t, tools[0].Name, "echo") + + // Snapshot should be fresh. 
+ assert.False(t, m.SnapshotChanged([]string{configPath})) + }) + + t.Run("ReloadAfterClose", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + require.NoError(t, m.Close()) + + err := m.Reload(ctx, []string{"/nonexistent"}) + require.Error(t, err, "reload after close should fail") + }) + + t.Run("ConcurrentReloadsCoalesce", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + // Launch multiple concurrent reloads. + const numCallers = 5 + var wg sync.WaitGroup + errs := make([]error, numCallers) + for i := range numCallers { + wg.Go(func() { + errs[i] = m.Reload(ctx, []string{configPath}) + }) + } + wg.Wait() + + for i, err := range errs { + assert.NoError(t, err, "caller %d should not fail", i) + } + + tools := m.Tools() + require.Len(t, tools, 1) + }) + + t.Run("CallerContextCanceled", func(t *testing.T) { + t.Parallel() + mgrCtx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(mgrCtx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + // Use an already-canceled caller context. + callerCtx, cancel := context.WithCancel(mgrCtx) + cancel() // Cancel immediately. + + err := m.Reload(callerCtx, []string{configPath}) + // The caller context is already canceled, so Reload should + // return the caller's context error. 
+ require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run("SequentialReloadsDiffDetect", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry1 := fakeMCPServerConfig(t, "srv1") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + // First reload. + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + tools1 := m.Tools() + require.Len(t, tools1, 1) + assert.Contains(t, tools1[0].Name, "srv1") + + // Rewrite config with a different server. + _, entry2 := fakeMCPServerConfig(t, "srv2") + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2}) + + // Second reload detects the change. + assert.True(t, m.SnapshotChanged([]string{configPath})) + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + tools2 := m.Tools() + require.Len(t, tools2, 1) + assert.Contains(t, tools2[0].Name, "srv2") + }) + + t.Run("PerServerConnectFailureUpdatesSnapshot", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + // Config with a nonexistent binary: connect will fail. + path := filepath.Join(dir, ".mcp.json") + data := `{"mcpServers":{"bad":{"command":"/nonexistent/binary","args":[]}}}` + require.NoError(t, os.WriteFile(path, []byte(data), 0o600)) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + // Reload should succeed (per-server failures are logged and + // swallowed) and snapshot should update. 
+ err := m.Reload(ctx, []string{path}) + require.NoError(t, err) + assert.False(t, m.SnapshotChanged([]string{path}), + "snapshot should be updated even on per-server connect failure") + }) + + t.Run("EmptyConfigClosesServers", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + require.Len(t, m.Tools(), 1) + + // Delete config file. + require.NoError(t, os.Remove(configPath)) + + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + assert.Empty(t, m.Tools(), "tools should be empty after config deleted") + + // Subsequent reload finds snapshot unchanged. + assert.False(t, m.SnapshotChanged([]string{configPath})) + }) +} + +func TestDifferentialReload(t *testing.T) { + t.Parallel() + + // These tests verify differential reload behavior: client + // reuse for unchanged servers, reconnect for changed ones, + // and close for removed ones. + + t.Run("UnchangedServerReusesClient", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + // Capture the client pointer. + m.mu.RLock() + origClient := m.servers["srv"].client + m.mu.RUnlock() + require.NotNil(t, origClient) + + // Add a new server without changing the existing one. 
+ _, entry2 := fakeMCPServerConfig(t, "srv2") + cfgMap := map[string]mcpServerEntry{"srv": entry, "srv2": entry2} + writeMCPConfig(t, dir, cfgMap) + + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + // The unchanged server should reuse the same client. + m.mu.RLock() + newClient := m.servers["srv"].client + m.mu.RUnlock() + assert.Same(t, origClient, newClient, + "unchanged server should reuse client pointer") + + // Both servers should have tools. + tools := m.Tools() + require.Len(t, tools, 2) + }) + + t.Run("ChangedServerGetsNewClient", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + m.mu.RLock() + origClient := m.servers["srv"].client + m.mu.RUnlock() + + // Change the server's args to trigger a diff. 
+ entry.Args = append(entry.Args, "-test.v") + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + m.mu.RLock() + newClient := m.servers["srv"].client + m.mu.RUnlock() + assert.NotSame(t, origClient, newClient, + "changed server should get a new client") + }) + + t.Run("RemovedServerIsClosed", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entryA := fakeMCPServerConfig(t, "srvA") + _, entryB := fakeMCPServerConfig(t, "srvB") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{ + "srvA": entryA, "srvB": entryB, + }) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + require.Len(t, m.Tools(), 2) + + // Capture srvB's client before removal. + m.mu.RLock() + oldClientB := m.servers["srvB"].client + m.mu.RUnlock() + require.NotNil(t, oldClientB) + + // Remove srvB from the config. + writeMCPConfig(t, dir, map[string]mcpServerEntry{"srvA": entryA}) + + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + tools := m.Tools() + require.Len(t, tools, 1) + assert.Contains(t, tools[0].Name, "srvA") + + // The old client for srvB should be closed. + // ListTools on a closed client returns an error. 
+ listCtx, cancel := context.WithTimeout(ctx, testutil.WaitShort) + defer cancel() + _, listErr := oldClientB.ListTools(listCtx, mcp.ListToolsRequest{}) + assert.Error(t, listErr, "ListTools on closed client should fail") + }) + + t.Run("ConnectFailureRetainsOldClient", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + require.Len(t, m.Tools(), 1) + + m.mu.RLock() + origClient := m.servers["srv"].client + m.mu.RUnlock() + + // Change config to use a bad command, so connect fails. + path := filepath.Join(dir, ".mcp.json") + data := `{"mcpServers":{"srv":{"command":"/nonexistent/binary","args":[]}}}` + require.NoError(t, os.WriteFile(path, []byte(data), 0o600)) + + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + // The old client should be retained because the new connect + // failed. + m.mu.RLock() + currentClient := m.servers["srv"].client + m.mu.RUnlock() + assert.Same(t, origClient, currentClient, + "failed connect should retain old client") + + // Tools should still work. 
+ tools := m.Tools() + require.Len(t, tools, 1) + }) + + t.Run("PostReloadToolCallReachesKeptServer", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + tools := m.Tools() + require.Len(t, tools, 1) + toolName := tools[0].Name + + // Add a second server (srv unchanged, so client is reused). + _, entry2 := fakeMCPServerConfig(t, "srv2") + writeMCPConfig(t, dir, map[string]mcpServerEntry{ + "srv": entry, "srv2": entry2, + }) + + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + // A tool call to the kept server should reach it. + // The client pointer for "srv" was reused, not replaced. + _, err = m.CallTool(ctx, workspacesdk.CallMCPToolRequest{ + ToolName: toolName, + }) + // The fake server does not implement tools/call, so we + // expect an error from the server, but the call itself + // should reach the server (not ErrUnknownServer). + require.Error(t, err, "fake server does not implement tools/call") + assert.NotErrorIs(t, err, ErrUnknownServer, + "tool call should reach the server, not fail with unknown server") + }) +} + +// TestReload_FirstBootPath verifies that the first-boot call site +// (agent.go) can be routed through Reload without behavioral change. 
+func TestReload_FirstBootPath(t *testing.T) { + t.Parallel() + + if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" { + runFakeMCPServer() + return + } + + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + // Simulate first-boot: Reload with the initial config. + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + tools := m.Tools() + require.Len(t, tools, 1) + assert.Contains(t, tools[0].Name, "echo") +} + +// TestReload_NoopWhenUnchanged verifies that Reload returns +// immediately without reconnecting when the snapshot is fresh. +func TestReload_NoopWhenUnchanged(t *testing.T) { + t.Parallel() + + if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" { + runFakeMCPServer() + return + } + + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + m.mu.RLock() + origClient := m.servers["srv"].client + m.mu.RUnlock() + + // Second reload with no changes should be a no-op. + err = m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + + m.mu.RLock() + sameClient := m.servers["srv"].client + m.mu.RUnlock() + + assert.Same(t, origClient, sameClient, + "no-op reload should not replace the client") +} + +// TestClose_SuppressesSubprocessExitError verifies that Close +// returns nil when servers have running subprocesses that exit +// with a kill signal during shutdown. 
+func TestClose_SuppressesSubprocessExitError(t *testing.T) { + t.Parallel() + + if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" { + runFakeMCPServer() + return + } + + ctx := testutil.Context(t, testutil.WaitLong) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + dir := t.TempDir() + + _, entry := fakeMCPServerConfig(t, "srv") + configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry}) + + m := NewManager(ctx, logger, agentexec.DefaultExecer, nil) + t.Cleanup(func() { _ = m.Close() }) + + err := m.Reload(ctx, []string{configPath}) + require.NoError(t, err) + require.Len(t, m.Tools(), 1, "server should be connected") + + // Close kills the subprocess. The ExitError guard should + // suppress the "signal: killed" error. + err = m.Close() + assert.NoError(t, err, "Close should not propagate subprocess kill errors") +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 3a662e8f1d..d044cff7e8 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -249,8 +249,9 @@ func (p *Server) loadCachedWorkspaceContext( } var tools []fantasy.AgentTool + invalidate := func() { p.workspaceMCPToolsCache.Delete(chatID) } for _, t := range entry.tools { - tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn)) + tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn, invalidate)) } return tools @@ -6290,9 +6291,10 @@ func (p *Server) runChat( } } + invalidate := func() { p.workspaceMCPToolsCache.Delete(chat.ID) } for _, t := range toolsResp.Tools { workspaceMCPTools = append(workspaceMCPTools, - chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn), + chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn, invalidate), ) } return nil diff --git a/coderd/x/chatd/chattool/mcpworkspace.go b/coderd/x/chatd/chattool/mcpworkspace.go index 5748c5b03f..1d2affc6d5 100644 --- a/coderd/x/chatd/chattool/mcpworkspace.go +++ b/coderd/x/chatd/chattool/mcpworkspace.go @@ -4,10 +4,13 @@ import ( "context" 
"encoding/base64" "encoding/json" + "errors" + "net/http" "strings" "charm.land/fantasy" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" ) @@ -16,17 +19,22 @@ import ( // connection. It implements fantasy.AgentTool so it can be // registered alongside built-in chat tools. type WorkspaceMCPTool struct { - info fantasy.ToolInfo - getConn func(context.Context) (workspacesdk.AgentConn, error) - providerOpts fantasy.ProviderOptions + info fantasy.ToolInfo + getConn func(context.Context) (workspacesdk.AgentConn, error) + providerOpts fantasy.ProviderOptions + invalidateCache func() } // NewWorkspaceMCPTool creates a tool wrapper from an MCPToolInfo // discovered on a workspace agent. Each tool proxies calls back -// through the agent connection. +// through the agent connection. The optional invalidateCache +// callback is invoked when CallMCPTool returns a 404 error, +// indicating that the server was removed and the chat's cached +// tool list should be dropped. func NewWorkspaceMCPTool( tool workspacesdk.MCPToolInfo, getConn func(context.Context) (workspacesdk.AgentConn, error), + invalidateCache func(), ) *WorkspaceMCPTool { required := tool.Required if required == nil { @@ -40,7 +48,8 @@ func NewWorkspaceMCPTool( Required: required, Parallel: true, }, - getConn: getConn, + getConn: getConn, + invalidateCache: invalidateCache, } } @@ -75,6 +84,15 @@ func (t *WorkspaceMCPTool) Run( Arguments: args, }) if err != nil { + // If the agent returns a 404 (ErrUnknownServer), the + // server was removed or renamed. Invalidate the chat's + // cached tool list so the next turn refetches. 
+ var coderErr *codersdk.Error + if errors.As(err, &coderErr) && coderErr.StatusCode() == http.StatusNotFound { + if t.invalidateCache != nil { + t.invalidateCache() + } + } return fantasy.NewTextErrorResponse(err.Error()), nil } diff --git a/coderd/x/chatd/chattool/mcpworkspace_test.go b/coderd/x/chatd/chattool/mcpworkspace_test.go new file mode 100644 index 0000000000..4306509abd --- /dev/null +++ b/coderd/x/chatd/chattool/mcpworkspace_test.go @@ -0,0 +1,155 @@ +package chattool_test + +import ( + "context" + "net/http" + "sync/atomic" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// fakeAgentConn implements just enough of workspacesdk.AgentConn +// for testing CallMCPTool. +type fakeAgentConn struct { + workspacesdk.AgentConn + callMCPToolFunc func(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) +} + +func (f *fakeAgentConn) CallMCPTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + return f.callMCPToolFunc(ctx, req) +} + +func TestWorkspaceMCPTool_InvalidateOn404(t *testing.T) { + t.Parallel() + + t.Run("404ErrorInvalidatesCache", func(t *testing.T) { + t.Parallel() + + var invalidated atomic.Bool + tool := chattool.NewWorkspaceMCPTool( + workspacesdk.MCPToolInfo{ + Name: "test__echo", + Description: "test tool", + }, + func(ctx context.Context) (workspacesdk.AgentConn, error) { + return &fakeAgentConn{ + callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + return workspacesdk.CallMCPToolResponse{}, codersdk.NewError( + http.StatusNotFound, + codersdk.Response{ + Message: "MCP tool call failed.", + Detail: `unknown MCP server: "test"`, + }, + ) + }, + }, nil + }, + 
func() { invalidated.Store(true) }, + ) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{}) + require.NoError(t, err) + assert.True(t, resp.IsError, "response should be an error") + assert.True(t, invalidated.Load(), + "invalidateCache should fire on 404") + }) + + t.Run("Non404DoesNotInvalidate", func(t *testing.T) { + t.Parallel() + + var invalidated atomic.Bool + tool := chattool.NewWorkspaceMCPTool( + workspacesdk.MCPToolInfo{ + Name: "test__echo", + Description: "test tool", + }, + func(ctx context.Context) (workspacesdk.AgentConn, error) { + return &fakeAgentConn{ + callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + return workspacesdk.CallMCPToolResponse{}, codersdk.NewError( + http.StatusBadGateway, + codersdk.Response{ + Message: "Bad Gateway", + }, + ) + }, + }, nil + }, + func() { invalidated.Store(true) }, + ) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{}) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.False(t, invalidated.Load(), + "invalidateCache should NOT fire on non-404 error") + }) + + t.Run("ToolLevelErrorNoInvalidation", func(t *testing.T) { + t.Parallel() + + var invalidated atomic.Bool + tool := chattool.NewWorkspaceMCPTool( + workspacesdk.MCPToolInfo{ + Name: "test__echo", + Description: "test tool", + }, + func(ctx context.Context) (workspacesdk.AgentConn, error) { + return &fakeAgentConn{ + callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + return workspacesdk.CallMCPToolResponse{ + IsError: true, + Content: []workspacesdk.MCPToolContent{ + {Type: "text", Text: "tool error"}, + }, + }, nil + }, + }, nil + }, + func() { invalidated.Store(true) }, + ) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{}) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.False(t, invalidated.Load(), + "invalidateCache should NOT fire 
on tool-level error (HTTP 200)") + }) + + t.Run("NilInvalidateCallbackSafe", func(t *testing.T) { + t.Parallel() + + tool := chattool.NewWorkspaceMCPTool( + workspacesdk.MCPToolInfo{ + Name: "test__echo", + Description: "test tool", + }, + func(ctx context.Context) (workspacesdk.AgentConn, error) { + return &fakeAgentConn{ + callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + return workspacesdk.CallMCPToolResponse{}, codersdk.NewError( + http.StatusNotFound, + codersdk.Response{ + Message: "MCP tool call failed.", + Detail: `unknown MCP server: "test"`, + }, + ) + }, + }, nil + }, + nil, + ) + + // Should not panic. + resp, err := tool.Run(context.Background(), fantasy.ToolCall{}) + require.NoError(t, err) + assert.True(t, resp.IsError) + }) +}