feat: reload MCP config on change via lazy stat-on-request (#24700)

The MCP manager previously read .mcp.json exactly once at agent startup.
Editing the file had no effect until workspace rebuild or agent restart.

handleListTools now stats config file mtimes on every tool-list request
and triggers a differential reload when any file changed. Unchanged
servers keep their client pointer so in-flight tool calls survive.
Concurrent reload requests coalesce via singleflight.

MCP stdio subprocesses use the agent's execer for resource limits and
receive the same enriched environment as SSH sessions via updateEnv.

On the chatd side, WorkspaceMCPTool.Run detects 404 responses from
CallMCPTool (indicating the server was removed) and drops the chat's
cached tool list so the next turn refetches from the agent.
This commit is contained in:
Mathias Fredriksson
2026-04-28 19:47:14 +03:00
committed by GitHub
parent 3f0e015fe5
commit 881df9a5b0
9 changed files with 1658 additions and 219 deletions
+4 -4
View File
@@ -423,14 +423,14 @@ func (a *agent) init() {
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil, a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil,
) )
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock) a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp")) a.mcpManager = agentmcp.NewManager(a.gracefulCtx, a.logger.Named("mcp"), a.execer, a.updateCommandEnv)
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager)
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string { a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
if m := a.manifest.Load(); m != nil { if m := a.manifest.Load(); m != nil {
return m.Directory return m.Directory
} }
return "" return ""
}, a.contextConfig) }, a.contextConfig)
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager, a.contextConfigAPI.MCPConfigFiles)
a.reconnectingPTYServer = reconnectingpty.NewServer( a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"), a.logger.Named("reconnecting-pty"),
a.sshServer, a.sshServer,
@@ -1413,8 +1413,8 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
// lifecycle transition to avoid delaying Ready. // lifecycle transition to avoid delaying Ready.
// This runs inside the tracked goroutine so it // This runs inside the tracked goroutine so it
// is properly awaited on shutdown. // is properly awaited on shutdown.
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil { if mcpErr := a.mcpManager.Reload(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil {
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr)) a.logger.Warn(ctx, "failed to reload workspace MCP servers", slog.Error(mcpErr))
} }
}) })
if err != nil { if err != nil {
+43 -9
View File
@@ -1,6 +1,7 @@
package agentmcp package agentmcp
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
@@ -15,16 +16,24 @@ import (
// API exposes MCP tool discovery and call proxying through the // API exposes MCP tool discovery and call proxying through the
// agent. // agent.
type API struct { type API struct {
logger slog.Logger logger slog.Logger
manager *Manager manager *Manager
mcpConfigFiles func() []string
} }
// NewAPI creates a new MCP API handler backed by the given // NewAPI creates a new MCP API handler backed by the given
// manager. // manager. The mcpConfigFiles callback returns the current
func NewAPI(logger slog.Logger, manager *Manager) *API { // resolved config file paths; it is called on every tool-list
// request to detect config changes.
func NewAPI(
logger slog.Logger,
manager *Manager,
mcpConfigFiles func() []string,
) *API {
return &API{ return &API{
logger: logger, logger: logger,
manager: manager, manager: manager,
mcpConfigFiles: mcpConfigFiles,
} }
} }
@@ -36,13 +45,38 @@ func (api *API) Routes() http.Handler {
return r return r
} }
// handleListTools returns the cached MCP tool definitions, // handleListTools checks whether any .mcp.json config file
// optionally refreshing them first if ?refresh=true is set. // has changed since the last reload, triggering a differential
// reload if so, then returns the cached MCP tool definitions.
// The ?refresh=true query parameter forces a tool re-scan
// independent of config changes.
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
// Check config freshness and reload if changed.
var reloaded bool
paths := api.mcpConfigFiles()
if api.manager.SnapshotChanged(paths) {
if err := api.manager.Reload(ctx, paths); err != nil {
// Categorize the error for operator debugging.
switch {
case errors.Is(err, context.Canceled):
api.logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err))
case errors.Is(err, context.DeadlineExceeded):
api.logger.Warn(ctx, "mcp reload timed out", slog.Error(err))
default:
api.logger.Warn(ctx, "mcp reload failed", slog.Error(err))
}
// Fall through to return whatever tools we have.
} else {
reloaded = true
}
}
// Allow callers to force a tool re-scan before listing. // Allow callers to force a tool re-scan before listing.
if r.URL.Query().Get("refresh") == "true" { // Skip if a config reload ran above, since it already
// refreshes tools as part of the reload.
if r.URL.Query().Get("refresh") == "true" && !reloaded {
if err := api.manager.RefreshTools(ctx); err != nil { if err := api.manager.RefreshTools(ctx); err != nil {
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err)) api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
} }
+228
View File
@@ -0,0 +1,228 @@
package agentmcp
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
func TestHandleListTools_ReloadOnChange(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
runFakeMCPServer()
return
}
// Cases that share the single-request-and-check pattern.
type singleRequestCase struct {
name string
entries func(t *testing.T) map[string]mcpServerEntry
reloadManager bool
closeManager bool
expectedTools int
toolNameContains string
}
cases := []singleRequestCase{
{
name: "InitialRequestNoReload",
entries: func(t *testing.T) map[string]mcpServerEntry {
t.Helper()
_, entry := fakeMCPServerConfig(t, "srv")
return map[string]mcpServerEntry{"srv": entry}
},
reloadManager: true,
expectedTools: 1,
toolNameContains: "echo",
},
{
name: "ManagerClosedReturnsEmpty",
entries: func(_ *testing.T) map[string]mcpServerEntry {
return map[string]mcpServerEntry{}
},
closeManager: true,
expectedTools: 0,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
configPath := writeMCPConfig(t, dir, tc.entries(t))
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
if tc.closeManager {
require.NoError(t, m.Close())
} else {
t.Cleanup(func() { _ = m.Close() })
}
if tc.reloadManager {
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
}
api := NewAPI(logger, m, func() []string {
return []string{configPath}
})
req := httptest.NewRequest(http.MethodGet, "/tools", nil)
rec := httptest.NewRecorder()
api.Routes().ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp workspacesdk.ListMCPToolsResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
require.Len(t, resp.Tools, tc.expectedTools)
if tc.toolNameContains != "" {
assert.Contains(t, resp.Tools[0].Name, tc.toolNameContains)
}
})
}
// ConfigChangeTriggersReload has a mutate-then-re-request flow
// that does not fit the single-request table pattern.
t.Run("ConfigChangeTriggersReload", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry1 := fakeMCPServerConfig(t, "srv1")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
api := NewAPI(logger, m, func() []string {
return []string{configPath}
})
// Verify initial tools.
req := httptest.NewRequest(http.MethodGet, "/tools", nil)
rec := httptest.NewRecorder()
api.Routes().ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp1 workspacesdk.ListMCPToolsResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp1))
require.Len(t, resp1.Tools, 1)
assert.Contains(t, resp1.Tools[0].Name, "srv1")
// Mutate the config file.
_, entry2 := fakeMCPServerConfig(t, "srv2")
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
// Next request should trigger a reload and return new tools.
req2 := httptest.NewRequest(http.MethodGet, "/tools", nil)
rec2 := httptest.NewRecorder()
api.Routes().ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
var resp2 workspacesdk.ListMCPToolsResponse
require.NoError(t, json.NewDecoder(rec2.Body).Decode(&resp2))
require.Len(t, resp2.Tools, 1)
assert.Contains(t, resp2.Tools[0].Name, "srv2")
})
}
func TestHandleListTools_RefreshParam(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
runFakeMCPServer()
return
}
t.Run("RefreshTrueUnchangedSnapshot", func(t *testing.T) {
// Exercises the ?refresh=true code path when the config
// snapshot is unchanged. Verifies the endpoint returns
// tools without error.
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
api := NewAPI(logger, m, func() []string {
return []string{configPath}
})
req := httptest.NewRequest(http.MethodGet, "/tools?refresh=true", nil)
rec := httptest.NewRecorder()
api.Routes().ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp workspacesdk.ListMCPToolsResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
// Tool should still be present after refresh.
require.Len(t, resp.Tools, 1)
assert.Contains(t, resp.Tools[0].Name, "echo")
})
t.Run("RefreshTrueWithChangedConfig", func(t *testing.T) {
// Exercises the ?refresh=true code path when the config
// has also changed. The reload path already calls
// RefreshTools, so the handler skips the redundant call.
// This test covers the branch; it cannot observe the
// skip without a mock.
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry1 := fakeMCPServerConfig(t, "srv1")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
api := NewAPI(logger, m, func() []string {
return []string{configPath}
})
// Mutate config.
_, entry2 := fakeMCPServerConfig(t, "srv2")
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
req := httptest.NewRequest(http.MethodGet, "/tools?refresh=true", nil)
rec := httptest.NewRecorder()
api.Routes().ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp workspacesdk.ListMCPToolsResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
require.Len(t, resp.Tools, 1)
assert.Contains(t, resp.Tools[0].Name, "srv2")
})
}
+491 -198
View File
@@ -5,7 +5,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"maps"
"os" "os"
"os/exec"
"reflect"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -16,8 +19,11 @@ import (
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/xerrors" "golang.org/x/xerrors"
tailscalesingleflight "tailscale.com/util/singleflight"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
) )
@@ -44,15 +50,30 @@ var (
ErrUnknownServer = xerrors.New("unknown MCP server") ErrUnknownServer = xerrors.New("unknown MCP server")
) )
// fileSnapshot records the identity of a config file at the time
// it was last read.
type fileSnapshot struct {
exists bool
modTime time.Time
size int64
}
// Manager manages connections to MCP servers discovered from a // Manager manages connections to MCP servers discovered from a
// workspace's .mcp.json file. It caches the aggregated tool list // workspace's .mcp.json file. It caches the aggregated tool list
// and proxies tool calls to the appropriate server. // and proxies tool calls to the appropriate server.
type Manager struct { type Manager struct {
mu sync.RWMutex ctx context.Context
logger slog.Logger execer agentexec.Execer
closed bool updateEnv func(current []string) ([]string, error)
servers map[string]*serverEntry // keyed by server name
tools []workspacesdk.MCPToolInfo mu sync.RWMutex
logger slog.Logger
closed bool
servers map[string]*serverEntry
tools []workspacesdk.MCPToolInfo
snapshot map[string]fileSnapshot
serverGen uint64
sf tailscalesingleflight.Group[string, struct{}]
} }
// serverEntry pairs a server config with its connected client. // serverEntry pairs a server config with its connected client.
@@ -61,18 +82,189 @@ type serverEntry struct {
client *client.Client client *client.Client
} }
// NewManager creates a new MCP client manager. // NewManager creates a new MCP client manager. The ctx bounds
func NewManager(logger slog.Logger) *Manager { // subprocess lifetime. The execer applies resource limits to
// MCP server subprocesses. The updateEnv callback enriches the
// subprocess environment to match interactive sessions.
func NewManager(
ctx context.Context,
logger slog.Logger,
execer agentexec.Execer,
updateEnv func([]string) ([]string, error),
) *Manager {
return &Manager{ return &Manager{
logger: logger, ctx: ctx,
servers: make(map[string]*serverEntry), logger: logger,
execer: execer,
updateEnv: updateEnv,
servers: make(map[string]*serverEntry),
snapshot: make(map[string]fileSnapshot),
} }
} }
// Connect reads MCP config files at the given absolute paths and // Reload checks whether config files have changed and, if so,
// connects to all configured servers. Failed servers are logged // performs a differential reconnect. Concurrent callers are
// and skipped. Missing config files are silently skipped. // coalesced via singleflight; the reload body runs under the
func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error { // Manager's lifetime context so it survives caller cancellation.
func (m *Manager) Reload(ctx context.Context, paths []string) error {
m.mu.RLock()
closed := m.closed
hasSnapshot := len(m.snapshot) > 0
m.mu.RUnlock()
if closed {
return xerrors.New("manager closed")
}
// Double-check: another goroutine may have completed a
// reload between the caller's SnapshotChanged and this
// call. The singleflight body uses its own resolved paths.
if hasSnapshot && !m.SnapshotChanged(paths) {
return nil
}
// All concurrent callers share one in-flight reload keyed
// by "". If a concurrent caller resolves different paths
// (e.g. after a manifest reconnect), its paths are not
// consulted; the next SnapshotChanged check after this
// reload completes will detect the mismatch and trigger
// a fresh reload.
ch := m.sf.DoChan("reload", func() (struct{}, error) {
err := m.doReload(m.ctx, paths)
return struct{}{}, err
})
select {
case <-ctx.Done():
return ctx.Err()
case res := <-ch:
return res.Err
}
}
// SnapshotChanged checks whether any config file has changed
// since the last reload by comparing os.Stat results against
// the stored snapshot.
func (m *Manager) SnapshotChanged(paths []string) bool {
seen := make(map[string]struct{}, len(paths))
unique := make([]string, 0, len(paths))
for _, p := range paths {
if _, ok := seen[p]; !ok {
seen[p] = struct{}{}
unique = append(unique, p)
}
}
paths = unique
m.mu.RLock()
snap := maps.Clone(m.snapshot)
snapshotLen := len(snap)
m.mu.RUnlock()
if len(paths) != snapshotLen {
return true
}
for _, p := range paths {
prev, ok := snap[p]
if !ok {
return true
}
info, err := os.Stat(p)
if err != nil {
// Stat failed; changed only if the file existed before.
if prev.exists {
return true
}
continue
}
// Stat succeeded but file was absent before: it appeared.
if !prev.exists {
return true
}
if !info.ModTime().Equal(prev.modTime) || info.Size() != prev.size {
return true
}
}
return false
}
// serverDiff is the output of classifyServers: which servers to
// connect, which to close, which to keep, and a snapshot of the
// previous map for fallback on connect failure.
type serverDiff struct {
toConnect []ServerConfig
toClose []*serverEntry
keep map[string]*serverEntry
prev map[string]*serverEntry
}
type connectedServer struct {
name string
config ServerConfig
client *client.Client
}
// doReload reads MCP config files and performs a differential
// reconnect. Unchanged servers keep their existing client; new or
// changed servers get a fresh connection; removed servers are
// closed.
func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error {
allConfigs, snap := m.parseAndDedup(ctx, mcpConfigFiles)
wanted := make(map[string]ServerConfig, len(allConfigs))
for _, cfg := range allConfigs {
wanted[cfg.Name] = cfg
}
diff, err := m.classifyServers(wanted)
if err != nil {
return err
}
connected := m.connectAll(ctx, diff.toConnect)
replaced, err := m.installServers(wanted, diff, connected, snap)
if err != nil {
return err
}
// Close removed and replaced servers outside the lock to
// avoid leaking child processes and to avoid blocking
// concurrent readers on subprocess I/O.
// Note: a concurrent CallTool that captured a removed
// entry's client before the swap may call a closed client.
// This is a narrow race that self-heals on the next request.
for _, entry := range diff.toClose {
_ = entry.client.Close()
}
for _, entry := range replaced {
_ = entry.client.Close()
}
// Refresh tools outside the lock to avoid blocking
// concurrent reads during network I/O.
if err := m.RefreshTools(ctx); err != nil {
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
}
return nil
}
// parseAndDedup reads all config files and returns a deduplicated
// list of server configs. Missing files are silently skipped;
// parse errors are logged and skipped.
func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) {
// Stat before reading so the snapshot is conservatively old.
// If a file changes between stat and read, the snapshot
// records the old mtime, SnapshotChanged detects a mismatch
// on the next check, and triggers a re-read. False positives
// (extra reload) are safe; false negatives (missed change)
// are not.
snap := captureSnapshot(mcpConfigFiles)
var allConfigs []ServerConfig var allConfigs []ServerConfig
for _, configPath := range mcpConfigFiles { for _, configPath := range mcpConfigFiles {
configs, err := ParseConfig(configPath) configs, err := ParseConfig(configPath)
@@ -99,26 +291,55 @@ func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
seen[cfg.Name] = struct{}{} seen[cfg.Name] = struct{}{}
deduped = append(deduped, cfg) deduped = append(deduped, cfg)
} }
allConfigs = deduped return deduped, snap
}
if len(allConfigs) == 0 { // classifyServers compares wanted configs against the current
return nil // server map and returns a diff describing what changed.
// Acquires and releases m.mu for reading.
func (m *Manager) classifyServers(wanted map[string]ServerConfig) (*serverDiff, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return nil, xerrors.New("manager closed")
} }
// Connect to servers in parallel without holding the diff := &serverDiff{
// lock, since each connectServer call may block on keep: make(map[string]*serverEntry),
// network I/O for up to connectTimeout.
type connectedServer struct {
name string
config ServerConfig
client *client.Client
} }
for name, wantCfg := range wanted {
if existing, ok := m.servers[name]; ok {
if reflect.DeepEqual(existing.config, wantCfg) {
diff.keep[name] = existing
} else {
diff.toConnect = append(diff.toConnect, wantCfg)
}
} else {
diff.toConnect = append(diff.toConnect, wantCfg)
}
}
for name, entry := range m.servers {
if _, ok := wanted[name]; !ok {
diff.toClose = append(diff.toClose, entry)
}
}
diff.prev = maps.Clone(m.servers)
return diff, nil
}
// connectAll runs connectServer in parallel for the given configs.
// Failed connects are logged and skipped.
func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer {
var ( var (
mu sync.Mutex mu sync.Mutex
connected []connectedServer connected []connectedServer
) )
var eg errgroup.Group var eg errgroup.Group
for _, cfg := range allConfigs { for _, cfg := range toConnect {
eg.Go(func() error { eg.Go(func() error {
c, err := m.connectServer(ctx, cfg) c, err := m.connectServer(ctx, cfg)
if err != nil { if err != nil {
@@ -138,131 +359,81 @@ func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
}) })
} }
_ = eg.Wait() _ = eg.Wait()
return connected
}
// installServers builds the new server map from diff.keep and the
// connected list, falling back to diff.prev when a connect failed.
// Returns old entries replaced by successful connects (caller
// closes them). Acquires and releases m.mu.
func (m *Manager) installServers(
wanted map[string]ServerConfig,
diff *serverDiff,
connected []connectedServer,
snap map[string]fileSnapshot,
) ([]*serverEntry, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock()
if m.closed { if m.closed {
m.mu.Unlock()
// Close the freshly-connected clients since we're
// shutting down.
for _, cs := range connected { for _, cs := range connected {
_ = cs.client.Close() _ = cs.client.Close()
} }
return xerrors.New("manager closed") return nil, xerrors.New("manager closed")
} }
// Close previous connections to avoid leaking child newConnected := make(map[string]connectedServer, len(connected))
// processes on agent reconnect.
for _, entry := range m.servers {
_ = entry.client.Close()
}
m.servers = make(map[string]*serverEntry, len(connected))
for _, cs := range connected { for _, cs := range connected {
m.servers[cs.name] = &serverEntry{ newConnected[cs.name] = cs
config: cs.config, }
client: cs.client,
newServers := make(map[string]*serverEntry, len(wanted))
for name, entry := range diff.keep {
newServers[name] = entry
}
var replaced []*serverEntry
for name, wantCfg := range wanted {
if _, kept := diff.keep[name]; kept {
continue
} }
} if cs, ok := newConnected[wantCfg.Name]; ok {
m.mu.Unlock() newServers[wantCfg.Name] = &serverEntry{
config: cs.config,
// Refresh tools outside the lock to avoid blocking client: cs.client,
// concurrent reads during network I/O. }
if err := m.RefreshTools(ctx); err != nil { if prev, existed := diff.prev[wantCfg.Name]; existed {
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) replaced = append(replaced, prev)
} }
return nil } else if prev, existed := diff.prev[wantCfg.Name]; existed {
} // Connect failed; retain the old client.
newServers[wantCfg.Name] = prev
// connectServer establishes a connection to a single MCP server
// and returns the connected client. It does not modify any Manager
// state.
func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
tr, err := createTransport(cfg)
if err != nil {
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
}
c := client.NewClient(tr)
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
// Use the parent ctx (not connectCtx) so the subprocess outlives
// the connect/initialize handshake. connectCtx bounds only the
// Initialize call below. The subprocess is cleaned up when the
// Manager is closed or ctx is canceled.
if err := c.Start(ctx); err != nil {
_ = c.Close()
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
}
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "coder-agent",
Version: buildinfo.Version(),
},
},
})
if err != nil {
_ = c.Close()
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
}
return c, nil
}
// createTransport builds the mcp-go transport for a server config.
func createTransport(cfg ServerConfig) (transport.Interface, error) {
switch cfg.Transport {
case "stdio":
return transport.NewStdio(
cfg.Command,
buildEnv(cfg.Env),
cfg.Args...,
), nil
case "http", "":
return transport.NewStreamableHTTP(
cfg.URL,
transport.WithHTTPHeaders(cfg.Headers),
)
case "sse":
return transport.NewSSE(
cfg.URL,
transport.WithHeaders(cfg.Headers),
)
default:
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
}
}
// buildEnv merges the current process environment with explicit
// overrides, returning the result as KEY=VALUE strings suitable
// for the stdio transport.
func buildEnv(explicit map[string]string) []string {
env := os.Environ()
if len(explicit) == 0 {
return env
}
// Index existing env so explicit keys can override in-place.
existing := make(map[string]int, len(env))
for i, kv := range env {
if k, _, ok := strings.Cut(kv, "="); ok {
existing[k] = i
} }
} }
for k, v := range explicit { m.servers = newServers
entry := k + "=" + v m.serverGen++
if idx, ok := existing[k]; ok { m.snapshot = snap
env[idx] = entry return replaced, nil
} else { }
env = append(env, entry)
// captureSnapshot stats each path and returns the current
// snapshot map.
func captureSnapshot(paths []string) map[string]fileSnapshot {
snap := make(map[string]fileSnapshot, len(paths))
for _, p := range paths {
info, err := os.Stat(p)
if err != nil {
snap[p] = fileSnapshot{exists: false}
continue
}
snap[p] = fileSnapshot{
exists: true,
modTime: info.ModTime(),
size: info.Size(),
} }
} }
return env return snap
} }
// Tools returns the cached tool list. Thread-safe. // Tools returns the cached tool list. Thread-safe.
@@ -304,68 +475,6 @@ func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequ
return convertResult(result), nil return convertResult(result), nil
} }
// splitToolName extracts the server name and original tool name
// from a prefixed tool name like "server__tool".
func splitToolName(prefixed string) (serverName, toolName string, err error) {
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
if !ok || server == "" || tool == "" {
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
}
return server, tool, nil
}
// convertResult translates an MCP CallToolResult into a
// workspacesdk.CallMCPToolResponse. It iterates over content
// items and maps each recognized type.
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
if result == nil {
return workspacesdk.CallMCPToolResponse{}
}
var content []workspacesdk.MCPToolContent
for _, item := range result.Content {
switch c := item.(type) {
case mcp.TextContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: c.Text,
})
case mcp.ImageContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "image",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.AudioContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "audio",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.EmbeddedResource:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
})
case mcp.ResourceLink:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[resource link: %s]", c.URI),
})
default:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: fmt.Sprintf("[unsupported content type: %T]", item),
})
}
}
return workspacesdk.CallMCPToolResponse{
Content: content,
IsError: result.IsError,
}
}
// RefreshTools re-fetches tool lists from all connected servers // RefreshTools re-fetches tool lists from all connected servers
// in parallel and rebuilds the cache. On partial failure, tools // in parallel and rebuilds the cache. On partial failure, tools
// from servers that responded successfully are merged with the // from servers that responded successfully are merged with the
@@ -378,6 +487,7 @@ func (m *Manager) RefreshTools(ctx context.Context) error {
for k, v := range m.servers { for k, v := range m.servers {
servers[k] = v servers[k] = v
} }
gen := m.serverGen
m.mu.RUnlock() m.mu.RUnlock()
// Fetch tool lists in parallel without holding any lock. // Fetch tool lists in parallel without holding any lock.
@@ -451,7 +561,12 @@ func (m *Manager) RefreshTools(ctx context.Context) error {
}) })
m.mu.Lock() m.mu.Lock()
m.tools = merged // Skip the write if the server map changed since the
// snapshot. A doReload that bumped the generation will
// produce a correct tool list; this write would be stale.
if m.serverGen == gen {
m.tools = merged
}
m.mu.Unlock() m.mu.Unlock()
return errors.Join(errs...) return errors.Join(errs...)
@@ -466,9 +581,187 @@ func (m *Manager) Close() error {
m.closed = true m.closed = true
var errs []error var errs []error
for _, entry := range m.servers { for _, entry := range m.servers {
errs = append(errs, entry.client.Close()) if err := entry.client.Close(); err != nil {
// Subprocess kill signals are expected during shutdown.
// The stdio transport returns cmd.Wait() which surfaces
// "signal: killed" as an exec.ExitError.
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) {
errs = append(errs, err)
}
}
} }
m.servers = make(map[string]*serverEntry) m.servers = make(map[string]*serverEntry)
m.tools = nil m.tools = nil
return errors.Join(errs...) return errors.Join(errs...)
} }
// connectServer establishes a connection to a single MCP server
// and returns the connected client. It does not modify any Manager
// state.
func (m *Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
tr, err := m.createTransport(ctx, cfg)
if err != nil {
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
}
c := client.NewClient(tr)
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
// Use the parent ctx (not connectCtx) so the subprocess outlives
// the connect/initialize handshake. connectCtx bounds only the
// Initialize call below. The subprocess is cleaned up when the
// Manager is closed or ctx is canceled.
if err := c.Start(ctx); err != nil {
_ = c.Close()
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
}
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "coder-agent",
Version: buildinfo.Version(),
},
},
})
if err != nil {
_ = c.Close()
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
}
return c, nil
}
// createTransport builds the mcp-go transport for a server config.
func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transport.Interface, error) {
switch cfg.Transport {
case "stdio":
env := m.buildEnv(ctx, cfg.Env)
return transport.NewStdioWithOptions(
cfg.Command,
env,
cfg.Args,
transport.WithCommandFunc(func(ctx context.Context, command string, cmdEnv []string, args []string) (*exec.Cmd, error) {
cmd := m.execer.CommandContext(ctx, command, args...)
cmd.Env = cmdEnv
return cmd, nil
}),
), nil
case "http", "":
return transport.NewStreamableHTTP(
cfg.URL,
transport.WithHTTPHeaders(cfg.Headers),
)
case "sse":
return transport.NewSSE(
cfg.URL,
transport.WithHeaders(cfg.Headers),
)
default:
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
}
}
// buildEnv enriches the process environment via the agent's
// updateEnv callback, then merges explicit overrides from the
// server config on top.
func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string {
env := usershell.SystemEnvInfo{}.Environ()
if m.updateEnv != nil {
var err error
env, err = m.updateEnv(env)
if err != nil {
m.logger.Warn(ctx, "failed to enrich MCP server environment",
slog.Error(err),
)
env = usershell.SystemEnvInfo{}.Environ()
}
}
if len(explicit) == 0 {
return env
}
// Index existing env so explicit keys can override in-place.
existing := make(map[string]int, len(env))
for i, kv := range env {
if k, _, ok := strings.Cut(kv, "="); ok {
existing[k] = i
}
}
for k, v := range explicit {
entry := k + "=" + v
if idx, ok := existing[k]; ok {
env[idx] = entry
} else {
env = append(env, entry)
}
}
return env
}
// splitToolName extracts the server name and original tool name
// from a prefixed tool name like "server__tool".
func splitToolName(prefixed string) (serverName, toolName string, err error) {
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
if !ok || server == "" || tool == "" {
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
}
return server, tool, nil
}
// convertResult translates an MCP CallToolResult into a
// workspacesdk.CallMCPToolResponse. It iterates over content
// items and maps each recognized type.
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
if result == nil {
return workspacesdk.CallMCPToolResponse{}
}
var content []workspacesdk.MCPToolContent
for _, item := range result.Content {
switch c := item.(type) {
case mcp.TextContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: c.Text,
})
case mcp.ImageContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "image",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.AudioContent:
content = append(content, workspacesdk.MCPToolContent{
Type: "audio",
Data: c.Data,
MediaType: c.MIMEType,
})
case mcp.EmbeddedResource:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
})
case mcp.ResourceLink:
content = append(content, workspacesdk.MCPToolContent{
Type: "resource",
Text: fmt.Sprintf("[resource link: %s]", c.URI),
})
default:
content = append(content, workspacesdk.MCPToolContent{
Type: "text",
Text: fmt.Sprintf("[unsupported content type: %T]", item),
})
}
}
return workspacesdk.CallMCPToolResponse{
Content: content,
IsError: result.IsError,
}
}
+2 -1
View File
@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@@ -227,7 +228,7 @@ func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
} }
ctx := testutil.Context(t, testutil.WaitLong) ctx := testutil.Context(t, testutil.WaitLong)
m := &Manager{} m := &Manager{execer: agentexec.DefaultExecer}
client, err := m.connectServer(ctx, cfg) client, err := m.connectServer(ctx, cfg)
require.NoError(t, err, "connectServer should succeed") require.NoError(t, err, "connectServer should succeed")
t.Cleanup(func() { _ = client.Close() }) t.Cleanup(func() { _ = client.Close() })
+708
View File
@@ -0,0 +1,708 @@
package agentmcp
import (
"context"
"encoding/json"
"os"
"path/filepath"
"sync"
"testing"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
// writeMCPConfig writes a .mcp.json file with the given server
// entries. Each entry maps a server name to its config.
func writeMCPConfig(t *testing.T, dir string, servers map[string]mcpServerEntry) string {
t.Helper()
path := filepath.Join(dir, ".mcp.json")
cfg := mcpConfigFile{MCPServers: make(map[string]json.RawMessage)}
for name, entry := range servers {
raw, err := json.Marshal(entry)
require.NoError(t, err)
cfg.MCPServers[name] = raw
}
data, err := json.Marshal(cfg)
require.NoError(t, err)
err = os.WriteFile(path, data, 0o600)
require.NoError(t, err)
return path
}
// fakeMCPServerConfig returns a ServerConfig that launches a fake
// MCP server using the test binary re-exec pattern.
func fakeMCPServerConfig(t *testing.T, name string) (ServerConfig, mcpServerEntry) {
t.Helper()
testBin, err := os.Executable()
require.NoError(t, err)
cfg := ServerConfig{
Name: name,
Transport: "stdio",
Command: testBin,
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
}
entry := mcpServerEntry{
Command: testBin,
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
}
return cfg, entry
}
func TestSnapshotChanged(t *testing.T) {
t.Parallel()
type testCase struct {
name string
setup func(t *testing.T, dir string) []string
mutate func(t *testing.T, dir string)
checkPaths func(t *testing.T, dir string, initialPaths []string) []string
want bool
}
cases := []testCase{
{
name: "UnchangedFiles",
setup: func(t *testing.T, dir string) []string {
t.Helper()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
return []string{configPath}
},
want: false,
},
{
name: "ContentChange",
setup: func(t *testing.T, dir string) []string {
t.Helper()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
return []string{configPath}
},
mutate: func(t *testing.T, dir string) {
t.Helper()
_, entry2 := fakeMCPServerConfig(t, "srv2")
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
},
want: true,
},
{
name: "FileBecomesMissing",
setup: func(t *testing.T, dir string) []string {
t.Helper()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
return []string{configPath}
},
mutate: func(t *testing.T, dir string) {
t.Helper()
require.NoError(t, os.Remove(filepath.Join(dir, ".mcp.json")))
},
want: true,
},
{
name: "FileAppears",
setup: func(t *testing.T, dir string) []string {
t.Helper()
return []string{filepath.Join(dir, ".mcp.json")}
},
mutate: func(t *testing.T, dir string) {
t.Helper()
_, entry := fakeMCPServerConfig(t, "srv")
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
},
want: true,
},
{
name: "BothAbsentUnchanged",
setup: func(t *testing.T, dir string) []string {
t.Helper()
return []string{filepath.Join(dir, ".mcp.json")}
},
want: false,
},
{
name: "PathSetDiffers",
setup: func(t *testing.T, dir string) []string {
t.Helper()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
return []string{configPath}
},
checkPaths: func(t *testing.T, dir string, initialPaths []string) []string {
t.Helper()
extraPath := filepath.Join(dir, "extra.mcp.json")
return append(initialPaths, extraPath)
},
want: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
paths := tc.setup(t, dir)
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, paths)
require.NoError(t, err)
if tc.mutate != nil {
tc.mutate(t, dir)
}
checkPaths := paths
if tc.checkPaths != nil {
checkPaths = tc.checkPaths(t, dir, paths)
}
changed := m.SnapshotChanged(checkPaths)
assert.Equal(t, tc.want, changed)
})
}
}
func TestSnapshotChanged_MultipleConfigFiles(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
runFakeMCPServer()
return
}
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir1 := t.TempDir()
dir2 := t.TempDir()
_, entry1 := fakeMCPServerConfig(t, "srv1")
_, entry2 := fakeMCPServerConfig(t, "srv2")
path1 := writeMCPConfig(t, dir1, map[string]mcpServerEntry{"srv1": entry1})
path2 := writeMCPConfig(t, dir2, map[string]mcpServerEntry{"srv2": entry2})
paths := []string{path1, path2}
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
// Initial reload with both config files.
err := m.Reload(ctx, paths)
require.NoError(t, err)
// Both files unchanged.
assert.False(t, m.SnapshotChanged(paths),
"snapshot should not change when both files are unchanged")
// Mutate only the second file.
_, entry2b := fakeMCPServerConfig(t, "srv2b")
writeMCPConfig(t, dir2, map[string]mcpServerEntry{"srv2b": entry2b})
assert.True(t, m.SnapshotChanged(paths),
"snapshot should change when second file is mutated")
// Reload picks up the mutation.
err = m.Reload(ctx, paths)
require.NoError(t, err)
// Tools from both files should be present.
tools := m.Tools()
require.Len(t, tools, 2, "should have tools from both config files")
assert.Contains(t, tools[0].Name, "srv1",
"first tool should be from first config")
assert.Contains(t, tools[1].Name, "srv2b",
"second tool should be from second config")
}
func TestReload(t *testing.T) {
t.Parallel()
t.Run("SingleReloadUpdatesSnapshot", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
tools := m.Tools()
require.Len(t, tools, 1, "should have one tool from the fake server")
assert.Contains(t, tools[0].Name, "echo")
// Snapshot should be fresh.
assert.False(t, m.SnapshotChanged([]string{configPath}))
})
t.Run("ReloadAfterClose", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
require.NoError(t, m.Close())
err := m.Reload(ctx, []string{"/nonexistent"})
require.Error(t, err, "reload after close should fail")
})
t.Run("ConcurrentReloadsCoalesce", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
// Launch multiple concurrent reloads.
const numCallers = 5
var wg sync.WaitGroup
errs := make([]error, numCallers)
for i := range numCallers {
wg.Go(func() {
errs[i] = m.Reload(ctx, []string{configPath})
})
}
wg.Wait()
for i, err := range errs {
assert.NoError(t, err, "caller %d should not fail", i)
}
tools := m.Tools()
require.Len(t, tools, 1)
})
t.Run("CallerContextCanceled", func(t *testing.T) {
t.Parallel()
mgrCtx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(mgrCtx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
// Use an already-canceled caller context.
callerCtx, cancel := context.WithCancel(mgrCtx)
cancel() // Cancel immediately.
err := m.Reload(callerCtx, []string{configPath})
// The caller context is already canceled, so Reload should
// return the caller's context error.
require.Error(t, err)
assert.ErrorIs(t, err, context.Canceled)
})
t.Run("SequentialReloadsDiffDetect", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry1 := fakeMCPServerConfig(t, "srv1")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
// First reload.
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
tools1 := m.Tools()
require.Len(t, tools1, 1)
assert.Contains(t, tools1[0].Name, "srv1")
// Rewrite config with a different server.
_, entry2 := fakeMCPServerConfig(t, "srv2")
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
// Second reload detects the change.
assert.True(t, m.SnapshotChanged([]string{configPath}))
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
tools2 := m.Tools()
require.Len(t, tools2, 1)
assert.Contains(t, tools2[0].Name, "srv2")
})
t.Run("PerServerConnectFailureUpdatesSnapshot", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
// Config with a nonexistent binary: connect will fail.
path := filepath.Join(dir, ".mcp.json")
data := `{"mcpServers":{"bad":{"command":"/nonexistent/binary","args":[]}}}`
require.NoError(t, os.WriteFile(path, []byte(data), 0o600))
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
// Reload should succeed (per-server failures are logged and
// swallowed) and snapshot should update.
err := m.Reload(ctx, []string{path})
require.NoError(t, err)
assert.False(t, m.SnapshotChanged([]string{path}),
"snapshot should be updated even on per-server connect failure")
})
t.Run("EmptyConfigClosesServers", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
require.Len(t, m.Tools(), 1)
// Delete config file.
require.NoError(t, os.Remove(configPath))
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
assert.Empty(t, m.Tools(), "tools should be empty after config deleted")
// Subsequent reload finds snapshot unchanged.
assert.False(t, m.SnapshotChanged([]string{configPath}))
})
}
func TestDifferentialReload(t *testing.T) {
t.Parallel()
// These tests verify differential reload behavior: client
// reuse for unchanged servers, reconnect for changed ones,
// and close for removed ones.
t.Run("UnchangedServerReusesClient", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
// Capture the client pointer.
m.mu.RLock()
origClient := m.servers["srv"].client
m.mu.RUnlock()
require.NotNil(t, origClient)
// Add a new server without changing the existing one.
_, entry2 := fakeMCPServerConfig(t, "srv2")
cfgMap := map[string]mcpServerEntry{"srv": entry, "srv2": entry2}
writeMCPConfig(t, dir, cfgMap)
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
// The unchanged server should reuse the same client.
m.mu.RLock()
newClient := m.servers["srv"].client
m.mu.RUnlock()
assert.Same(t, origClient, newClient,
"unchanged server should reuse client pointer")
// Both servers should have tools.
tools := m.Tools()
require.Len(t, tools, 2)
})
t.Run("ChangedServerGetsNewClient", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
m.mu.RLock()
origClient := m.servers["srv"].client
m.mu.RUnlock()
// Change the server's args to trigger a diff.
entry.Args = append(entry.Args, "-test.v")
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
m.mu.RLock()
newClient := m.servers["srv"].client
m.mu.RUnlock()
assert.NotSame(t, origClient, newClient,
"changed server should get a new client")
})
t.Run("RemovedServerIsClosed", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entryA := fakeMCPServerConfig(t, "srvA")
_, entryB := fakeMCPServerConfig(t, "srvB")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{
"srvA": entryA, "srvB": entryB,
})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
require.Len(t, m.Tools(), 2)
// Capture srvB's client before removal.
m.mu.RLock()
oldClientB := m.servers["srvB"].client
m.mu.RUnlock()
require.NotNil(t, oldClientB)
// Remove srvB from the config.
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srvA": entryA})
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
tools := m.Tools()
require.Len(t, tools, 1)
assert.Contains(t, tools[0].Name, "srvA")
// The old client for srvB should be closed.
// ListTools on a closed client returns an error.
listCtx, cancel := context.WithTimeout(ctx, testutil.WaitShort)
defer cancel()
_, listErr := oldClientB.ListTools(listCtx, mcp.ListToolsRequest{})
assert.Error(t, listErr, "ListTools on closed client should fail")
})
t.Run("ConnectFailureRetainsOldClient", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
require.Len(t, m.Tools(), 1)
m.mu.RLock()
origClient := m.servers["srv"].client
m.mu.RUnlock()
// Change config to use a bad command, so connect fails.
path := filepath.Join(dir, ".mcp.json")
data := `{"mcpServers":{"srv":{"command":"/nonexistent/binary","args":[]}}}`
require.NoError(t, os.WriteFile(path, []byte(data), 0o600))
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
// The old client should be retained because the new connect
// failed.
m.mu.RLock()
currentClient := m.servers["srv"].client
m.mu.RUnlock()
assert.Same(t, origClient, currentClient,
"failed connect should retain old client")
// Tools should still work.
tools := m.Tools()
require.Len(t, tools, 1)
})
t.Run("PostReloadToolCallReachesKeptServer", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
tools := m.Tools()
require.Len(t, tools, 1)
toolName := tools[0].Name
// Add a second server (srv unchanged, so client is reused).
_, entry2 := fakeMCPServerConfig(t, "srv2")
writeMCPConfig(t, dir, map[string]mcpServerEntry{
"srv": entry, "srv2": entry2,
})
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
// A tool call to the kept server should reach it.
// The client pointer for "srv" was reused, not replaced.
_, err = m.CallTool(ctx, workspacesdk.CallMCPToolRequest{
ToolName: toolName,
})
// The fake server does not implement tools/call, so we
// expect an error from the server, but the call itself
// should reach the server (not ErrUnknownServer).
require.Error(t, err, "fake server does not implement tools/call")
assert.NotErrorIs(t, err, ErrUnknownServer,
"tool call should reach the server, not fail with unknown server")
})
}
// TestReload_FirstBootPath verifies that the first-boot call site
// (agent.go) can be routed through Reload without behavioral change.
func TestReload_FirstBootPath(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
runFakeMCPServer()
return
}
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
// Simulate first-boot: Reload with the initial config.
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
tools := m.Tools()
require.Len(t, tools, 1)
assert.Contains(t, tools[0].Name, "echo")
}
// TestReload_NoopWhenUnchanged verifies that Reload returns
// immediately without reconnecting when the snapshot is fresh.
func TestReload_NoopWhenUnchanged(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
runFakeMCPServer()
return
}
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
m.mu.RLock()
origClient := m.servers["srv"].client
m.mu.RUnlock()
// Second reload with no changes should be a no-op.
err = m.Reload(ctx, []string{configPath})
require.NoError(t, err)
m.mu.RLock()
sameClient := m.servers["srv"].client
m.mu.RUnlock()
assert.Same(t, origClient, sameClient,
"no-op reload should not replace the client")
}
// TestClose_SuppressesSubprocessExitError verifies that Close
// returns nil when servers have running subprocesses that exit
// with a kill signal during shutdown.
func TestClose_SuppressesSubprocessExitError(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
runFakeMCPServer()
return
}
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
dir := t.TempDir()
_, entry := fakeMCPServerConfig(t, "srv")
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
t.Cleanup(func() { _ = m.Close() })
err := m.Reload(ctx, []string{configPath})
require.NoError(t, err)
require.Len(t, m.Tools(), 1, "server should be connected")
// Close kills the subprocess. The ExitError guard should
// suppress the "signal: killed" error.
err = m.Close()
assert.NoError(t, err, "Close should not propagate subprocess kill errors")
}
+4 -2
View File
@@ -249,8 +249,9 @@ func (p *Server) loadCachedWorkspaceContext(
} }
var tools []fantasy.AgentTool var tools []fantasy.AgentTool
invalidate := func() { p.workspaceMCPToolsCache.Delete(chatID) }
for _, t := range entry.tools { for _, t := range entry.tools {
tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn)) tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn, invalidate))
} }
return tools return tools
@@ -6290,9 +6291,10 @@ func (p *Server) runChat(
} }
} }
invalidate := func() { p.workspaceMCPToolsCache.Delete(chat.ID) }
for _, t := range toolsResp.Tools { for _, t := range toolsResp.Tools {
workspaceMCPTools = append(workspaceMCPTools, workspaceMCPTools = append(workspaceMCPTools,
chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn), chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn, invalidate),
) )
} }
return nil return nil
+23 -5
View File
@@ -4,10 +4,13 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"net/http"
"strings" "strings"
"charm.land/fantasy" "charm.land/fantasy"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
) )
@@ -16,17 +19,22 @@ import (
// connection. It implements fantasy.AgentTool so it can be // connection. It implements fantasy.AgentTool so it can be
// registered alongside built-in chat tools. // registered alongside built-in chat tools.
type WorkspaceMCPTool struct { type WorkspaceMCPTool struct {
info fantasy.ToolInfo info fantasy.ToolInfo
getConn func(context.Context) (workspacesdk.AgentConn, error) getConn func(context.Context) (workspacesdk.AgentConn, error)
providerOpts fantasy.ProviderOptions providerOpts fantasy.ProviderOptions
invalidateCache func()
} }
// NewWorkspaceMCPTool creates a tool wrapper from an MCPToolInfo // NewWorkspaceMCPTool creates a tool wrapper from an MCPToolInfo
// discovered on a workspace agent. Each tool proxies calls back // discovered on a workspace agent. Each tool proxies calls back
// through the agent connection. // through the agent connection. The optional invalidateCache
// callback is invoked when CallMCPTool returns a 404 error,
// indicating that the server was removed and the chat's cached
// tool list should be dropped.
func NewWorkspaceMCPTool( func NewWorkspaceMCPTool(
tool workspacesdk.MCPToolInfo, tool workspacesdk.MCPToolInfo,
getConn func(context.Context) (workspacesdk.AgentConn, error), getConn func(context.Context) (workspacesdk.AgentConn, error),
invalidateCache func(),
) *WorkspaceMCPTool { ) *WorkspaceMCPTool {
required := tool.Required required := tool.Required
if required == nil { if required == nil {
@@ -40,7 +48,8 @@ func NewWorkspaceMCPTool(
Required: required, Required: required,
Parallel: true, Parallel: true,
}, },
getConn: getConn, getConn: getConn,
invalidateCache: invalidateCache,
} }
} }
@@ -75,6 +84,15 @@ func (t *WorkspaceMCPTool) Run(
Arguments: args, Arguments: args,
}) })
if err != nil { if err != nil {
// If the agent returns a 404 (ErrUnknownServer), the
// server was removed or renamed. Invalidate the chat's
// cached tool list so the next turn refetches.
var coderErr *codersdk.Error
if errors.As(err, &coderErr) && coderErr.StatusCode() == http.StatusNotFound {
if t.invalidateCache != nil {
t.invalidateCache()
}
}
return fantasy.NewTextErrorResponse(err.Error()), nil return fantasy.NewTextErrorResponse(err.Error()), nil
} }
@@ -0,0 +1,155 @@
package chattool_test
import (
"context"
"net/http"
"sync/atomic"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// fakeAgentConn implements just enough of workspacesdk.AgentConn
// for testing CallMCPTool.
type fakeAgentConn struct {
workspacesdk.AgentConn
callMCPToolFunc func(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error)
}
func (f *fakeAgentConn) CallMCPTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
return f.callMCPToolFunc(ctx, req)
}
func TestWorkspaceMCPTool_InvalidateOn404(t *testing.T) {
t.Parallel()
t.Run("404ErrorInvalidatesCache", func(t *testing.T) {
t.Parallel()
var invalidated atomic.Bool
tool := chattool.NewWorkspaceMCPTool(
workspacesdk.MCPToolInfo{
Name: "test__echo",
Description: "test tool",
},
func(ctx context.Context) (workspacesdk.AgentConn, error) {
return &fakeAgentConn{
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
return workspacesdk.CallMCPToolResponse{}, codersdk.NewError(
http.StatusNotFound,
codersdk.Response{
Message: "MCP tool call failed.",
Detail: `unknown MCP server: "test"`,
},
)
},
}, nil
},
func() { invalidated.Store(true) },
)
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
require.NoError(t, err)
assert.True(t, resp.IsError, "response should be an error")
assert.True(t, invalidated.Load(),
"invalidateCache should fire on 404")
})
t.Run("Non404DoesNotInvalidate", func(t *testing.T) {
t.Parallel()
var invalidated atomic.Bool
tool := chattool.NewWorkspaceMCPTool(
workspacesdk.MCPToolInfo{
Name: "test__echo",
Description: "test tool",
},
func(ctx context.Context) (workspacesdk.AgentConn, error) {
return &fakeAgentConn{
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
return workspacesdk.CallMCPToolResponse{}, codersdk.NewError(
http.StatusBadGateway,
codersdk.Response{
Message: "Bad Gateway",
},
)
},
}, nil
},
func() { invalidated.Store(true) },
)
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.False(t, invalidated.Load(),
"invalidateCache should NOT fire on non-404 error")
})
t.Run("ToolLevelErrorNoInvalidation", func(t *testing.T) {
t.Parallel()
var invalidated atomic.Bool
tool := chattool.NewWorkspaceMCPTool(
workspacesdk.MCPToolInfo{
Name: "test__echo",
Description: "test tool",
},
func(ctx context.Context) (workspacesdk.AgentConn, error) {
return &fakeAgentConn{
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
return workspacesdk.CallMCPToolResponse{
IsError: true,
Content: []workspacesdk.MCPToolContent{
{Type: "text", Text: "tool error"},
},
}, nil
},
}, nil
},
func() { invalidated.Store(true) },
)
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.False(t, invalidated.Load(),
"invalidateCache should NOT fire on tool-level error (HTTP 200)")
})
t.Run("NilInvalidateCallbackSafe", func(t *testing.T) {
t.Parallel()
tool := chattool.NewWorkspaceMCPTool(
workspacesdk.MCPToolInfo{
Name: "test__echo",
Description: "test tool",
},
func(ctx context.Context) (workspacesdk.AgentConn, error) {
return &fakeAgentConn{
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
return workspacesdk.CallMCPToolResponse{}, codersdk.NewError(
http.StatusNotFound,
codersdk.Response{
Message: "MCP tool call failed.",
Detail: `unknown MCP server: "test"`,
},
)
},
}, nil
},
nil,
)
// Should not panic.
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
require.NoError(t, err)
assert.True(t, resp.IsError)
})
}