mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: reload MCP config on change via lazy stat-on-request (#24700)
The MCP manager previously read .mcp.json exactly once at agent startup. Editing the file had no effect until workspace rebuild or agent restart. handleListTools now stats config file mtimes on every tool-list request and triggers a differential reload when any file changed. Unchanged servers keep their client pointer so in-flight tool calls survive. Concurrent reload requests coalesce via singleflight. MCP stdio subprocesses use the agent's execer for resource limits and receive the same enriched environment as SSH sessions via updateEnv. On the chatd side, WorkspaceMCPTool.Run detects 404 responses from CallMCPTool (indicating the server was removed) and drops the chat's cached tool list so the next turn refetches from the agent.
This commit is contained in:
committed by
GitHub
parent
3f0e015fe5
commit
881df9a5b0
+4
-4
@@ -423,14 +423,14 @@ func (a *agent) init() {
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(), nil,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.mcpManager = agentmcp.NewManager(a.logger.Named("mcp"))
|
||||
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager)
|
||||
a.mcpManager = agentmcp.NewManager(a.gracefulCtx, a.logger.Named("mcp"), a.execer, a.updateCommandEnv)
|
||||
a.contextConfigAPI = agentcontextconfig.NewAPI(func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
}, a.contextConfig)
|
||||
a.mcpAPI = agentmcp.NewAPI(a.logger.Named("mcp"), a.mcpManager, a.contextConfigAPI.MCPConfigFiles)
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
@@ -1413,8 +1413,8 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
|
||||
// lifecycle transition to avoid delaying Ready.
|
||||
// This runs inside the tracked goroutine so it
|
||||
// is properly awaited on shutdown.
|
||||
if mcpErr := a.mcpManager.Connect(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil {
|
||||
a.logger.Warn(ctx, "failed to connect to workspace MCP servers", slog.Error(mcpErr))
|
||||
if mcpErr := a.mcpManager.Reload(a.gracefulCtx, a.contextConfigAPI.MCPConfigFiles()); mcpErr != nil {
|
||||
a.logger.Warn(ctx, "failed to reload workspace MCP servers", slog.Error(mcpErr))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
+43
-9
@@ -1,6 +1,7 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
@@ -15,16 +16,24 @@ import (
|
||||
// API exposes MCP tool discovery and call proxying through the
|
||||
// agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
manager *Manager
|
||||
logger slog.Logger
|
||||
manager *Manager
|
||||
mcpConfigFiles func() []string
|
||||
}
|
||||
|
||||
// NewAPI creates a new MCP API handler backed by the given
|
||||
// manager.
|
||||
func NewAPI(logger slog.Logger, manager *Manager) *API {
|
||||
// manager. The mcpConfigFiles callback returns the current
|
||||
// resolved config file paths; it is called on every tool-list
|
||||
// request to detect config changes.
|
||||
func NewAPI(
|
||||
logger slog.Logger,
|
||||
manager *Manager,
|
||||
mcpConfigFiles func() []string,
|
||||
) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
manager: manager,
|
||||
mcpConfigFiles: mcpConfigFiles,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,13 +45,38 @@ func (api *API) Routes() http.Handler {
|
||||
return r
|
||||
}
|
||||
|
||||
// handleListTools returns the cached MCP tool definitions,
|
||||
// optionally refreshing them first if ?refresh=true is set.
|
||||
// handleListTools checks whether any .mcp.json config file
|
||||
// has changed since the last reload, triggering a differential
|
||||
// reload if so, then returns the cached MCP tool definitions.
|
||||
// The ?refresh=true query parameter forces a tool re-scan
|
||||
// independent of config changes.
|
||||
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Check config freshness and reload if changed.
|
||||
var reloaded bool
|
||||
paths := api.mcpConfigFiles()
|
||||
if api.manager.SnapshotChanged(paths) {
|
||||
if err := api.manager.Reload(ctx, paths); err != nil {
|
||||
// Categorize the error for operator debugging.
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
api.logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err))
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
api.logger.Warn(ctx, "mcp reload timed out", slog.Error(err))
|
||||
default:
|
||||
api.logger.Warn(ctx, "mcp reload failed", slog.Error(err))
|
||||
}
|
||||
// Fall through to return whatever tools we have.
|
||||
} else {
|
||||
reloaded = true
|
||||
}
|
||||
}
|
||||
|
||||
// Allow callers to force a tool re-scan before listing.
|
||||
if r.URL.Query().Get("refresh") == "true" {
|
||||
// Skip if a config reload ran above, since it already
|
||||
// refreshes tools as part of the reload.
|
||||
if r.URL.Query().Get("refresh") == "true" && !reloaded {
|
||||
if err := api.manager.RefreshTools(ctx); err != nil {
|
||||
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestHandleListTools_ReloadOnChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
// Cases that share the single-request-and-check pattern.
|
||||
type singleRequestCase struct {
|
||||
name string
|
||||
entries func(t *testing.T) map[string]mcpServerEntry
|
||||
reloadManager bool
|
||||
closeManager bool
|
||||
expectedTools int
|
||||
toolNameContains string
|
||||
}
|
||||
|
||||
cases := []singleRequestCase{
|
||||
{
|
||||
name: "InitialRequestNoReload",
|
||||
entries: func(t *testing.T) map[string]mcpServerEntry {
|
||||
t.Helper()
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
return map[string]mcpServerEntry{"srv": entry}
|
||||
},
|
||||
reloadManager: true,
|
||||
expectedTools: 1,
|
||||
toolNameContains: "echo",
|
||||
},
|
||||
{
|
||||
name: "ManagerClosedReturnsEmpty",
|
||||
entries: func(_ *testing.T) map[string]mcpServerEntry {
|
||||
return map[string]mcpServerEntry{}
|
||||
},
|
||||
closeManager: true,
|
||||
expectedTools: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
configPath := writeMCPConfig(t, dir, tc.entries(t))
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
if tc.closeManager {
|
||||
require.NoError(t, m.Close())
|
||||
} else {
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
}
|
||||
|
||||
if tc.reloadManager {
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
api := NewAPI(logger, m, func() []string {
|
||||
return []string{configPath}
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/tools", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
api.Routes().ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp workspacesdk.ListMCPToolsResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
require.Len(t, resp.Tools, tc.expectedTools)
|
||||
if tc.toolNameContains != "" {
|
||||
assert.Contains(t, resp.Tools[0].Name, tc.toolNameContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ConfigChangeTriggersReload has a mutate-then-re-request flow
|
||||
// that does not fit the single-request table pattern.
|
||||
t.Run("ConfigChangeTriggersReload", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry1 := fakeMCPServerConfig(t, "srv1")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
api := NewAPI(logger, m, func() []string {
|
||||
return []string{configPath}
|
||||
})
|
||||
|
||||
// Verify initial tools.
|
||||
req := httptest.NewRequest(http.MethodGet, "/tools", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
api.Routes().ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp1 workspacesdk.ListMCPToolsResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp1))
|
||||
require.Len(t, resp1.Tools, 1)
|
||||
assert.Contains(t, resp1.Tools[0].Name, "srv1")
|
||||
|
||||
// Mutate the config file.
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
|
||||
|
||||
// Next request should trigger a reload and return new tools.
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/tools", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
api.Routes().ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
|
||||
var resp2 workspacesdk.ListMCPToolsResponse
|
||||
require.NoError(t, json.NewDecoder(rec2.Body).Decode(&resp2))
|
||||
require.Len(t, resp2.Tools, 1)
|
||||
assert.Contains(t, resp2.Tools[0].Name, "srv2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleListTools_RefreshParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("RefreshTrueUnchangedSnapshot", func(t *testing.T) {
|
||||
// Exercises the ?refresh=true code path when the config
|
||||
// snapshot is unchanged. Verifies the endpoint returns
|
||||
// tools without error.
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
api := NewAPI(logger, m, func() []string {
|
||||
return []string{configPath}
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/tools?refresh=true", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
api.Routes().ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp workspacesdk.ListMCPToolsResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
// Tool should still be present after refresh.
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Contains(t, resp.Tools[0].Name, "echo")
|
||||
})
|
||||
|
||||
t.Run("RefreshTrueWithChangedConfig", func(t *testing.T) {
|
||||
// Exercises the ?refresh=true code path when the config
|
||||
// has also changed. The reload path already calls
|
||||
// RefreshTools, so the handler skips the redundant call.
|
||||
// This test covers the branch; it cannot observe the
|
||||
// skip without a mock.
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry1 := fakeMCPServerConfig(t, "srv1")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
api := NewAPI(logger, m, func() []string {
|
||||
return []string{configPath}
|
||||
})
|
||||
|
||||
// Mutate config.
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/tools?refresh=true", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
api.Routes().ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp workspacesdk.ListMCPToolsResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
require.Len(t, resp.Tools, 1)
|
||||
assert.Contains(t, resp.Tools[0].Name, "srv2")
|
||||
})
|
||||
}
|
||||
+491
-198
@@ -5,7 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -16,8 +19,11 @@ import (
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
tailscalesingleflight "tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/usershell"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
@@ -44,15 +50,30 @@ var (
|
||||
ErrUnknownServer = xerrors.New("unknown MCP server")
|
||||
)
|
||||
|
||||
// fileSnapshot records the identity of a config file at the time
|
||||
// it was last read.
|
||||
type fileSnapshot struct {
|
||||
exists bool
|
||||
modTime time.Time
|
||||
size int64
|
||||
}
|
||||
|
||||
// Manager manages connections to MCP servers discovered from a
|
||||
// workspace's .mcp.json file. It caches the aggregated tool list
|
||||
// and proxies tool calls to the appropriate server.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
logger slog.Logger
|
||||
closed bool
|
||||
servers map[string]*serverEntry // keyed by server name
|
||||
tools []workspacesdk.MCPToolInfo
|
||||
ctx context.Context
|
||||
execer agentexec.Execer
|
||||
updateEnv func(current []string) ([]string, error)
|
||||
|
||||
mu sync.RWMutex
|
||||
logger slog.Logger
|
||||
closed bool
|
||||
servers map[string]*serverEntry
|
||||
tools []workspacesdk.MCPToolInfo
|
||||
snapshot map[string]fileSnapshot
|
||||
serverGen uint64
|
||||
sf tailscalesingleflight.Group[string, struct{}]
|
||||
}
|
||||
|
||||
// serverEntry pairs a server config with its connected client.
|
||||
@@ -61,18 +82,189 @@ type serverEntry struct {
|
||||
client *client.Client
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP client manager.
|
||||
func NewManager(logger slog.Logger) *Manager {
|
||||
// NewManager creates a new MCP client manager. The ctx bounds
|
||||
// subprocess lifetime. The execer applies resource limits to
|
||||
// MCP server subprocesses. The updateEnv callback enriches the
|
||||
// subprocess environment to match interactive sessions.
|
||||
func NewManager(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
updateEnv func([]string) ([]string, error),
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
servers: make(map[string]*serverEntry),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
updateEnv: updateEnv,
|
||||
servers: make(map[string]*serverEntry),
|
||||
snapshot: make(map[string]fileSnapshot),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect reads MCP config files at the given absolute paths and
|
||||
// connects to all configured servers. Failed servers are logged
|
||||
// and skipped. Missing config files are silently skipped.
|
||||
func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
|
||||
// Reload checks whether config files have changed and, if so,
|
||||
// performs a differential reconnect. Concurrent callers are
|
||||
// coalesced via singleflight; the reload body runs under the
|
||||
// Manager's lifetime context so it survives caller cancellation.
|
||||
func (m *Manager) Reload(ctx context.Context, paths []string) error {
|
||||
m.mu.RLock()
|
||||
closed := m.closed
|
||||
hasSnapshot := len(m.snapshot) > 0
|
||||
m.mu.RUnlock()
|
||||
if closed {
|
||||
return xerrors.New("manager closed")
|
||||
}
|
||||
|
||||
// Double-check: another goroutine may have completed a
|
||||
// reload between the caller's SnapshotChanged and this
|
||||
// call. The singleflight body uses its own resolved paths.
|
||||
if hasSnapshot && !m.SnapshotChanged(paths) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// All concurrent callers share one in-flight reload keyed
|
||||
// by "". If a concurrent caller resolves different paths
|
||||
// (e.g. after a manifest reconnect), its paths are not
|
||||
// consulted; the next SnapshotChanged check after this
|
||||
// reload completes will detect the mismatch and trigger
|
||||
// a fresh reload.
|
||||
ch := m.sf.DoChan("reload", func() (struct{}, error) {
|
||||
err := m.doReload(m.ctx, paths)
|
||||
return struct{}{}, err
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case res := <-ch:
|
||||
return res.Err
|
||||
}
|
||||
}
|
||||
|
||||
// SnapshotChanged checks whether any config file has changed
|
||||
// since the last reload by comparing os.Stat results against
|
||||
// the stored snapshot.
|
||||
func (m *Manager) SnapshotChanged(paths []string) bool {
|
||||
seen := make(map[string]struct{}, len(paths))
|
||||
unique := make([]string, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
if _, ok := seen[p]; !ok {
|
||||
seen[p] = struct{}{}
|
||||
unique = append(unique, p)
|
||||
}
|
||||
}
|
||||
paths = unique
|
||||
|
||||
m.mu.RLock()
|
||||
snap := maps.Clone(m.snapshot)
|
||||
snapshotLen := len(snap)
|
||||
m.mu.RUnlock()
|
||||
|
||||
if len(paths) != snapshotLen {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
prev, ok := snap[p]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
// Stat failed; changed only if the file existed before.
|
||||
if prev.exists {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Stat succeeded but file was absent before: it appeared.
|
||||
if !prev.exists {
|
||||
return true
|
||||
}
|
||||
|
||||
if !info.ModTime().Equal(prev.modTime) || info.Size() != prev.size {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// serverDiff is the output of classifyServers: which servers to
|
||||
// connect, which to close, which to keep, and a snapshot of the
|
||||
// previous map for fallback on connect failure.
|
||||
type serverDiff struct {
|
||||
toConnect []ServerConfig
|
||||
toClose []*serverEntry
|
||||
keep map[string]*serverEntry
|
||||
prev map[string]*serverEntry
|
||||
}
|
||||
|
||||
type connectedServer struct {
|
||||
name string
|
||||
config ServerConfig
|
||||
client *client.Client
|
||||
}
|
||||
|
||||
// doReload reads MCP config files and performs a differential
|
||||
// reconnect. Unchanged servers keep their existing client; new or
|
||||
// changed servers get a fresh connection; removed servers are
|
||||
// closed.
|
||||
func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error {
|
||||
allConfigs, snap := m.parseAndDedup(ctx, mcpConfigFiles)
|
||||
|
||||
wanted := make(map[string]ServerConfig, len(allConfigs))
|
||||
for _, cfg := range allConfigs {
|
||||
wanted[cfg.Name] = cfg
|
||||
}
|
||||
|
||||
diff, err := m.classifyServers(wanted)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connected := m.connectAll(ctx, diff.toConnect)
|
||||
|
||||
replaced, err := m.installServers(wanted, diff, connected, snap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close removed and replaced servers outside the lock to
|
||||
// avoid leaking child processes and to avoid blocking
|
||||
// concurrent readers on subprocess I/O.
|
||||
// Note: a concurrent CallTool that captured a removed
|
||||
// entry's client before the swap may call a closed client.
|
||||
// This is a narrow race that self-heals on the next request.
|
||||
for _, entry := range diff.toClose {
|
||||
_ = entry.client.Close()
|
||||
}
|
||||
for _, entry := range replaced {
|
||||
_ = entry.client.Close()
|
||||
}
|
||||
|
||||
// Refresh tools outside the lock to avoid blocking
|
||||
// concurrent reads during network I/O.
|
||||
if err := m.RefreshTools(ctx); err != nil {
|
||||
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndDedup reads all config files and returns a deduplicated
|
||||
// list of server configs. Missing files are silently skipped;
|
||||
// parse errors are logged and skipped.
|
||||
func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) {
|
||||
// Stat before reading so the snapshot is conservatively old.
|
||||
// If a file changes between stat and read, the snapshot
|
||||
// records the old mtime, SnapshotChanged detects a mismatch
|
||||
// on the next check, and triggers a re-read. False positives
|
||||
// (extra reload) are safe; false negatives (missed change)
|
||||
// are not.
|
||||
snap := captureSnapshot(mcpConfigFiles)
|
||||
|
||||
var allConfigs []ServerConfig
|
||||
for _, configPath := range mcpConfigFiles {
|
||||
configs, err := ParseConfig(configPath)
|
||||
@@ -99,26 +291,55 @@ func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
|
||||
seen[cfg.Name] = struct{}{}
|
||||
deduped = append(deduped, cfg)
|
||||
}
|
||||
allConfigs = deduped
|
||||
return deduped, snap
|
||||
}
|
||||
|
||||
if len(allConfigs) == 0 {
|
||||
return nil
|
||||
// classifyServers compares wanted configs against the current
|
||||
// server map and returns a diff describing what changed.
|
||||
// Acquires and releases m.mu for reading.
|
||||
func (m *Manager) classifyServers(wanted map[string]ServerConfig) (*serverDiff, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, xerrors.New("manager closed")
|
||||
}
|
||||
|
||||
// Connect to servers in parallel without holding the
|
||||
// lock, since each connectServer call may block on
|
||||
// network I/O for up to connectTimeout.
|
||||
type connectedServer struct {
|
||||
name string
|
||||
config ServerConfig
|
||||
client *client.Client
|
||||
diff := &serverDiff{
|
||||
keep: make(map[string]*serverEntry),
|
||||
}
|
||||
|
||||
for name, wantCfg := range wanted {
|
||||
if existing, ok := m.servers[name]; ok {
|
||||
if reflect.DeepEqual(existing.config, wantCfg) {
|
||||
diff.keep[name] = existing
|
||||
} else {
|
||||
diff.toConnect = append(diff.toConnect, wantCfg)
|
||||
}
|
||||
} else {
|
||||
diff.toConnect = append(diff.toConnect, wantCfg)
|
||||
}
|
||||
}
|
||||
|
||||
for name, entry := range m.servers {
|
||||
if _, ok := wanted[name]; !ok {
|
||||
diff.toClose = append(diff.toClose, entry)
|
||||
}
|
||||
}
|
||||
|
||||
diff.prev = maps.Clone(m.servers)
|
||||
return diff, nil
|
||||
}
|
||||
|
||||
// connectAll runs connectServer in parallel for the given configs.
|
||||
// Failed connects are logged and skipped.
|
||||
func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
connected []connectedServer
|
||||
)
|
||||
var eg errgroup.Group
|
||||
for _, cfg := range allConfigs {
|
||||
for _, cfg := range toConnect {
|
||||
eg.Go(func() error {
|
||||
c, err := m.connectServer(ctx, cfg)
|
||||
if err != nil {
|
||||
@@ -138,131 +359,81 @@ func (m *Manager) Connect(ctx context.Context, mcpConfigFiles []string) error {
|
||||
})
|
||||
}
|
||||
_ = eg.Wait()
|
||||
return connected
|
||||
}
|
||||
|
||||
// installServers builds the new server map from diff.keep and the
|
||||
// connected list, falling back to diff.prev when a connect failed.
|
||||
// Returns old entries replaced by successful connects (caller
|
||||
// closes them). Acquires and releases m.mu.
|
||||
func (m *Manager) installServers(
|
||||
wanted map[string]ServerConfig,
|
||||
diff *serverDiff,
|
||||
connected []connectedServer,
|
||||
snap map[string]fileSnapshot,
|
||||
) ([]*serverEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
// Close the freshly-connected clients since we're
|
||||
// shutting down.
|
||||
for _, cs := range connected {
|
||||
_ = cs.client.Close()
|
||||
}
|
||||
return xerrors.New("manager closed")
|
||||
return nil, xerrors.New("manager closed")
|
||||
}
|
||||
|
||||
// Close previous connections to avoid leaking child
|
||||
// processes on agent reconnect.
|
||||
for _, entry := range m.servers {
|
||||
_ = entry.client.Close()
|
||||
}
|
||||
m.servers = make(map[string]*serverEntry, len(connected))
|
||||
|
||||
newConnected := make(map[string]connectedServer, len(connected))
|
||||
for _, cs := range connected {
|
||||
m.servers[cs.name] = &serverEntry{
|
||||
config: cs.config,
|
||||
client: cs.client,
|
||||
newConnected[cs.name] = cs
|
||||
}
|
||||
|
||||
newServers := make(map[string]*serverEntry, len(wanted))
|
||||
for name, entry := range diff.keep {
|
||||
newServers[name] = entry
|
||||
}
|
||||
|
||||
var replaced []*serverEntry
|
||||
for name, wantCfg := range wanted {
|
||||
if _, kept := diff.keep[name]; kept {
|
||||
continue
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Refresh tools outside the lock to avoid blocking
|
||||
// concurrent reads during network I/O.
|
||||
if err := m.RefreshTools(ctx); err != nil {
|
||||
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectServer establishes a connection to a single MCP server
|
||||
// and returns the connected client. It does not modify any Manager
|
||||
// state.
|
||||
func (*Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
|
||||
tr, err := createTransport(cfg)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
c := client.NewClient(tr)
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Use the parent ctx (not connectCtx) so the subprocess outlives
|
||||
// the connect/initialize handshake. connectCtx bounds only the
|
||||
// Initialize call below. The subprocess is cleaned up when the
|
||||
// Manager is closed or ctx is canceled.
|
||||
if err := c.Start(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "coder-agent",
|
||||
Version: buildinfo.Version(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// createTransport builds the mcp-go transport for a server config.
|
||||
func createTransport(cfg ServerConfig) (transport.Interface, error) {
|
||||
switch cfg.Transport {
|
||||
case "stdio":
|
||||
return transport.NewStdio(
|
||||
cfg.Command,
|
||||
buildEnv(cfg.Env),
|
||||
cfg.Args...,
|
||||
), nil
|
||||
case "http", "":
|
||||
return transport.NewStreamableHTTP(
|
||||
cfg.URL,
|
||||
transport.WithHTTPHeaders(cfg.Headers),
|
||||
)
|
||||
case "sse":
|
||||
return transport.NewSSE(
|
||||
cfg.URL,
|
||||
transport.WithHeaders(cfg.Headers),
|
||||
)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEnv merges the current process environment with explicit
|
||||
// overrides, returning the result as KEY=VALUE strings suitable
|
||||
// for the stdio transport.
|
||||
func buildEnv(explicit map[string]string) []string {
|
||||
env := os.Environ()
|
||||
if len(explicit) == 0 {
|
||||
return env
|
||||
}
|
||||
|
||||
// Index existing env so explicit keys can override in-place.
|
||||
existing := make(map[string]int, len(env))
|
||||
for i, kv := range env {
|
||||
if k, _, ok := strings.Cut(kv, "="); ok {
|
||||
existing[k] = i
|
||||
if cs, ok := newConnected[wantCfg.Name]; ok {
|
||||
newServers[wantCfg.Name] = &serverEntry{
|
||||
config: cs.config,
|
||||
client: cs.client,
|
||||
}
|
||||
if prev, existed := diff.prev[wantCfg.Name]; existed {
|
||||
replaced = append(replaced, prev)
|
||||
}
|
||||
} else if prev, existed := diff.prev[wantCfg.Name]; existed {
|
||||
// Connect failed; retain the old client.
|
||||
newServers[wantCfg.Name] = prev
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range explicit {
|
||||
entry := k + "=" + v
|
||||
if idx, ok := existing[k]; ok {
|
||||
env[idx] = entry
|
||||
} else {
|
||||
env = append(env, entry)
|
||||
m.servers = newServers
|
||||
m.serverGen++
|
||||
m.snapshot = snap
|
||||
return replaced, nil
|
||||
}
|
||||
|
||||
// captureSnapshot stats each path and returns the current
|
||||
// snapshot map.
|
||||
func captureSnapshot(paths []string) map[string]fileSnapshot {
|
||||
snap := make(map[string]fileSnapshot, len(paths))
|
||||
for _, p := range paths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
snap[p] = fileSnapshot{exists: false}
|
||||
continue
|
||||
}
|
||||
snap[p] = fileSnapshot{
|
||||
exists: true,
|
||||
modTime: info.ModTime(),
|
||||
size: info.Size(),
|
||||
}
|
||||
}
|
||||
return env
|
||||
return snap
|
||||
}
|
||||
|
||||
// Tools returns the cached tool list. Thread-safe.
|
||||
@@ -304,68 +475,6 @@ func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequ
|
||||
return convertResult(result), nil
|
||||
}
|
||||
|
||||
// splitToolName extracts the server name and original tool name
|
||||
// from a prefixed tool name like "server__tool".
|
||||
func splitToolName(prefixed string) (serverName, toolName string, err error) {
|
||||
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
|
||||
if !ok || server == "" || tool == "" {
|
||||
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
|
||||
}
|
||||
return server, tool, nil
|
||||
}
|
||||
|
||||
// convertResult translates an MCP CallToolResult into a
|
||||
// workspacesdk.CallMCPToolResponse. It iterates over content
|
||||
// items and maps each recognized type.
|
||||
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
|
||||
if result == nil {
|
||||
return workspacesdk.CallMCPToolResponse{}
|
||||
}
|
||||
|
||||
var content []workspacesdk.MCPToolContent
|
||||
for _, item := range result.Content {
|
||||
switch c := item.(type) {
|
||||
case mcp.TextContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "text",
|
||||
Text: c.Text,
|
||||
})
|
||||
case mcp.ImageContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "image",
|
||||
Data: c.Data,
|
||||
MediaType: c.MIMEType,
|
||||
})
|
||||
case mcp.AudioContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "audio",
|
||||
Data: c.Data,
|
||||
MediaType: c.MIMEType,
|
||||
})
|
||||
case mcp.EmbeddedResource:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "resource",
|
||||
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
|
||||
})
|
||||
case mcp.ResourceLink:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "resource",
|
||||
Text: fmt.Sprintf("[resource link: %s]", c.URI),
|
||||
})
|
||||
default:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("[unsupported content type: %T]", item),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return workspacesdk.CallMCPToolResponse{
|
||||
Content: content,
|
||||
IsError: result.IsError,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshTools re-fetches tool lists from all connected servers
|
||||
// in parallel and rebuilds the cache. On partial failure, tools
|
||||
// from servers that responded successfully are merged with the
|
||||
@@ -378,6 +487,7 @@ func (m *Manager) RefreshTools(ctx context.Context) error {
|
||||
for k, v := range m.servers {
|
||||
servers[k] = v
|
||||
}
|
||||
gen := m.serverGen
|
||||
m.mu.RUnlock()
|
||||
|
||||
// Fetch tool lists in parallel without holding any lock.
|
||||
@@ -451,7 +561,12 @@ func (m *Manager) RefreshTools(ctx context.Context) error {
|
||||
})
|
||||
|
||||
m.mu.Lock()
|
||||
m.tools = merged
|
||||
// Skip the write if the server map changed since the
|
||||
// snapshot. A doReload that bumped the generation will
|
||||
// produce a correct tool list; this write would be stale.
|
||||
if m.serverGen == gen {
|
||||
m.tools = merged
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return errors.Join(errs...)
|
||||
@@ -466,9 +581,187 @@ func (m *Manager) Close() error {
|
||||
m.closed = true
|
||||
var errs []error
|
||||
for _, entry := range m.servers {
|
||||
errs = append(errs, entry.client.Close())
|
||||
if err := entry.client.Close(); err != nil {
|
||||
// Subprocess kill signals are expected during shutdown.
|
||||
// The stdio transport returns cmd.Wait() which surfaces
|
||||
// "signal: killed" as an exec.ExitError.
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(err, &exitErr) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.servers = make(map[string]*serverEntry)
|
||||
m.tools = nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// connectServer establishes a connection to a single MCP server
|
||||
// and returns the connected client. It does not modify any Manager
|
||||
// state.
|
||||
func (m *Manager) connectServer(ctx context.Context, cfg ServerConfig) (*client.Client, error) {
|
||||
tr, err := m.createTransport(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create transport for %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
c := client.NewClient(tr)
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Use the parent ctx (not connectCtx) so the subprocess outlives
|
||||
// the connect/initialize handshake. connectCtx bounds only the
|
||||
// Initialize call below. The subprocess is cleaned up when the
|
||||
// Manager is closed or ctx is canceled.
|
||||
if err := c.Start(ctx); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("start %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
_, err = c.Initialize(connectCtx, mcp.InitializeRequest{
|
||||
Params: mcp.InitializeParams{
|
||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||
ClientInfo: mcp.Implementation{
|
||||
Name: "coder-agent",
|
||||
Version: buildinfo.Version(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, xerrors.Errorf("initialize %q: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// createTransport builds the mcp-go transport for a server config.
|
||||
func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transport.Interface, error) {
|
||||
switch cfg.Transport {
|
||||
case "stdio":
|
||||
env := m.buildEnv(ctx, cfg.Env)
|
||||
return transport.NewStdioWithOptions(
|
||||
cfg.Command,
|
||||
env,
|
||||
cfg.Args,
|
||||
transport.WithCommandFunc(func(ctx context.Context, command string, cmdEnv []string, args []string) (*exec.Cmd, error) {
|
||||
cmd := m.execer.CommandContext(ctx, command, args...)
|
||||
cmd.Env = cmdEnv
|
||||
return cmd, nil
|
||||
}),
|
||||
), nil
|
||||
case "http", "":
|
||||
return transport.NewStreamableHTTP(
|
||||
cfg.URL,
|
||||
transport.WithHTTPHeaders(cfg.Headers),
|
||||
)
|
||||
case "sse":
|
||||
return transport.NewSSE(
|
||||
cfg.URL,
|
||||
transport.WithHeaders(cfg.Headers),
|
||||
)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
// buildEnv enriches the process environment via the agent's
|
||||
// updateEnv callback, then merges explicit overrides from the
|
||||
// server config on top.
|
||||
func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string {
|
||||
env := usershell.SystemEnvInfo{}.Environ()
|
||||
if m.updateEnv != nil {
|
||||
var err error
|
||||
env, err = m.updateEnv(env)
|
||||
if err != nil {
|
||||
m.logger.Warn(ctx, "failed to enrich MCP server environment",
|
||||
slog.Error(err),
|
||||
)
|
||||
env = usershell.SystemEnvInfo{}.Environ()
|
||||
}
|
||||
}
|
||||
if len(explicit) == 0 {
|
||||
return env
|
||||
}
|
||||
|
||||
// Index existing env so explicit keys can override in-place.
|
||||
existing := make(map[string]int, len(env))
|
||||
for i, kv := range env {
|
||||
if k, _, ok := strings.Cut(kv, "="); ok {
|
||||
existing[k] = i
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range explicit {
|
||||
entry := k + "=" + v
|
||||
if idx, ok := existing[k]; ok {
|
||||
env[idx] = entry
|
||||
} else {
|
||||
env = append(env, entry)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// splitToolName extracts the server name and original tool name
|
||||
// from a prefixed tool name like "server__tool".
|
||||
func splitToolName(prefixed string) (serverName, toolName string, err error) {
|
||||
server, tool, ok := strings.Cut(prefixed, ToolNameSep)
|
||||
if !ok || server == "" || tool == "" {
|
||||
return "", "", xerrors.Errorf("%w: expected format \"server%stool\", got %q", ErrInvalidToolName, ToolNameSep, prefixed)
|
||||
}
|
||||
return server, tool, nil
|
||||
}
|
||||
|
||||
// convertResult translates an MCP CallToolResult into a
|
||||
// workspacesdk.CallMCPToolResponse. It iterates over content
|
||||
// items and maps each recognized type.
|
||||
func convertResult(result *mcp.CallToolResult) workspacesdk.CallMCPToolResponse {
|
||||
if result == nil {
|
||||
return workspacesdk.CallMCPToolResponse{}
|
||||
}
|
||||
|
||||
var content []workspacesdk.MCPToolContent
|
||||
for _, item := range result.Content {
|
||||
switch c := item.(type) {
|
||||
case mcp.TextContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "text",
|
||||
Text: c.Text,
|
||||
})
|
||||
case mcp.ImageContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "image",
|
||||
Data: c.Data,
|
||||
MediaType: c.MIMEType,
|
||||
})
|
||||
case mcp.AudioContent:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "audio",
|
||||
Data: c.Data,
|
||||
MediaType: c.MIMEType,
|
||||
})
|
||||
case mcp.EmbeddedResource:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "resource",
|
||||
Text: fmt.Sprintf("[embedded resource: %T]", c.Resource),
|
||||
})
|
||||
case mcp.ResourceLink:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "resource",
|
||||
Text: fmt.Sprintf("[resource link: %s]", c.URI),
|
||||
})
|
||||
default:
|
||||
content = append(content, workspacesdk.MCPToolContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("[unsupported content type: %T]", item),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return workspacesdk.CallMCPToolResponse{
|
||||
Content: content,
|
||||
IsError: result.IsError,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -227,7 +228,7 @@ func TestConnectServer_StdioProcessSurvivesConnect(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
m := &Manager{}
|
||||
m := &Manager{execer: agentexec.DefaultExecer}
|
||||
client, err := m.connectServer(ctx, cfg)
|
||||
require.NoError(t, err, "connectServer should succeed")
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
@@ -0,0 +1,708 @@
|
||||
package agentmcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// writeMCPConfig writes a .mcp.json file with the given server
|
||||
// entries. Each entry maps a server name to its config.
|
||||
func writeMCPConfig(t *testing.T, dir string, servers map[string]mcpServerEntry) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(dir, ".mcp.json")
|
||||
cfg := mcpConfigFile{MCPServers: make(map[string]json.RawMessage)}
|
||||
for name, entry := range servers {
|
||||
raw, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
cfg.MCPServers[name] = raw
|
||||
}
|
||||
data, err := json.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(path, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
return path
|
||||
}
|
||||
|
||||
// fakeMCPServerConfig returns a ServerConfig that launches a fake
|
||||
// MCP server using the test binary re-exec pattern.
|
||||
func fakeMCPServerConfig(t *testing.T, name string) (ServerConfig, mcpServerEntry) {
|
||||
t.Helper()
|
||||
testBin, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
cfg := ServerConfig{
|
||||
Name: name,
|
||||
Transport: "stdio",
|
||||
Command: testBin,
|
||||
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
|
||||
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
|
||||
}
|
||||
entry := mcpServerEntry{
|
||||
Command: testBin,
|
||||
Args: []string{"-test.run=^TestConnectServer_StdioProcessSurvivesConnect$"},
|
||||
Env: map[string]string{"TEST_MCP_FAKE_SERVER": "1"},
|
||||
}
|
||||
return cfg, entry
|
||||
}
|
||||
|
||||
func TestSnapshotChanged(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
setup func(t *testing.T, dir string) []string
|
||||
mutate func(t *testing.T, dir string)
|
||||
checkPaths func(t *testing.T, dir string, initialPaths []string) []string
|
||||
want bool
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "UnchangedFiles",
|
||||
setup: func(t *testing.T, dir string) []string {
|
||||
t.Helper()
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
return []string{configPath}
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ContentChange",
|
||||
setup: func(t *testing.T, dir string) []string {
|
||||
t.Helper()
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
return []string{configPath}
|
||||
},
|
||||
mutate: func(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FileBecomesMissing",
|
||||
setup: func(t *testing.T, dir string) []string {
|
||||
t.Helper()
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
return []string{configPath}
|
||||
},
|
||||
mutate: func(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, ".mcp.json")))
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FileAppears",
|
||||
setup: func(t *testing.T, dir string) []string {
|
||||
t.Helper()
|
||||
return []string{filepath.Join(dir, ".mcp.json")}
|
||||
},
|
||||
mutate: func(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "BothAbsentUnchanged",
|
||||
setup: func(t *testing.T, dir string) []string {
|
||||
t.Helper()
|
||||
return []string{filepath.Join(dir, ".mcp.json")}
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PathSetDiffers",
|
||||
setup: func(t *testing.T, dir string) []string {
|
||||
t.Helper()
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
return []string{configPath}
|
||||
},
|
||||
checkPaths: func(t *testing.T, dir string, initialPaths []string) []string {
|
||||
t.Helper()
|
||||
extraPath := filepath.Join(dir, "extra.mcp.json")
|
||||
return append(initialPaths, extraPath)
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
paths := tc.setup(t, dir)
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, paths)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.mutate != nil {
|
||||
tc.mutate(t, dir)
|
||||
}
|
||||
|
||||
checkPaths := paths
|
||||
if tc.checkPaths != nil {
|
||||
checkPaths = tc.checkPaths(t, dir, paths)
|
||||
}
|
||||
|
||||
changed := m.SnapshotChanged(checkPaths)
|
||||
assert.Equal(t, tc.want, changed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotChanged_MultipleConfigFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
dir1 := t.TempDir()
|
||||
dir2 := t.TempDir()
|
||||
|
||||
_, entry1 := fakeMCPServerConfig(t, "srv1")
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
path1 := writeMCPConfig(t, dir1, map[string]mcpServerEntry{"srv1": entry1})
|
||||
path2 := writeMCPConfig(t, dir2, map[string]mcpServerEntry{"srv2": entry2})
|
||||
paths := []string{path1, path2}
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// Initial reload with both config files.
|
||||
err := m.Reload(ctx, paths)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both files unchanged.
|
||||
assert.False(t, m.SnapshotChanged(paths),
|
||||
"snapshot should not change when both files are unchanged")
|
||||
|
||||
// Mutate only the second file.
|
||||
_, entry2b := fakeMCPServerConfig(t, "srv2b")
|
||||
writeMCPConfig(t, dir2, map[string]mcpServerEntry{"srv2b": entry2b})
|
||||
|
||||
assert.True(t, m.SnapshotChanged(paths),
|
||||
"snapshot should change when second file is mutated")
|
||||
|
||||
// Reload picks up the mutation.
|
||||
err = m.Reload(ctx, paths)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Tools from both files should be present.
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 2, "should have tools from both config files")
|
||||
assert.Contains(t, tools[0].Name, "srv1",
|
||||
"first tool should be from first config")
|
||||
assert.Contains(t, tools[1].Name, "srv2b",
|
||||
"second tool should be from second config")
|
||||
}
|
||||
|
||||
func TestReload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SingleReloadUpdatesSnapshot", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 1, "should have one tool from the fake server")
|
||||
assert.Contains(t, tools[0].Name, "echo")
|
||||
|
||||
// Snapshot should be fresh.
|
||||
assert.False(t, m.SnapshotChanged([]string{configPath}))
|
||||
})
|
||||
|
||||
t.Run("ReloadAfterClose", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, m.Close())
|
||||
|
||||
err := m.Reload(ctx, []string{"/nonexistent"})
|
||||
require.Error(t, err, "reload after close should fail")
|
||||
})
|
||||
|
||||
t.Run("ConcurrentReloadsCoalesce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// Launch multiple concurrent reloads.
|
||||
const numCallers = 5
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, numCallers)
|
||||
for i := range numCallers {
|
||||
wg.Go(func() {
|
||||
errs[i] = m.Reload(ctx, []string{configPath})
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, err := range errs {
|
||||
assert.NoError(t, err, "caller %d should not fail", i)
|
||||
}
|
||||
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 1)
|
||||
})
|
||||
|
||||
t.Run("CallerContextCanceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mgrCtx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(mgrCtx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// Use an already-canceled caller context.
|
||||
callerCtx, cancel := context.WithCancel(mgrCtx)
|
||||
cancel() // Cancel immediately.
|
||||
|
||||
err := m.Reload(callerCtx, []string{configPath})
|
||||
// The caller context is already canceled, so Reload should
|
||||
// return the caller's context error.
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
})
|
||||
|
||||
t.Run("SequentialReloadsDiffDetect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry1 := fakeMCPServerConfig(t, "srv1")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv1": entry1})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// First reload.
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
tools1 := m.Tools()
|
||||
require.Len(t, tools1, 1)
|
||||
assert.Contains(t, tools1[0].Name, "srv1")
|
||||
|
||||
// Rewrite config with a different server.
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv2": entry2})
|
||||
|
||||
// Second reload detects the change.
|
||||
assert.True(t, m.SnapshotChanged([]string{configPath}))
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
tools2 := m.Tools()
|
||||
require.Len(t, tools2, 1)
|
||||
assert.Contains(t, tools2[0].Name, "srv2")
|
||||
})
|
||||
|
||||
t.Run("PerServerConnectFailureUpdatesSnapshot", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
// Config with a nonexistent binary: connect will fail.
|
||||
path := filepath.Join(dir, ".mcp.json")
|
||||
data := `{"mcpServers":{"bad":{"command":"/nonexistent/binary","args":[]}}}`
|
||||
require.NoError(t, os.WriteFile(path, []byte(data), 0o600))
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// Reload should succeed (per-server failures are logged and
|
||||
// swallowed) and snapshot should update.
|
||||
err := m.Reload(ctx, []string{path})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, m.SnapshotChanged([]string{path}),
|
||||
"snapshot should be updated even on per-server connect failure")
|
||||
})
|
||||
|
||||
t.Run("EmptyConfigClosesServers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.Tools(), 1)
|
||||
|
||||
// Delete config file.
|
||||
require.NoError(t, os.Remove(configPath))
|
||||
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, m.Tools(), "tools should be empty after config deleted")
|
||||
|
||||
// Subsequent reload finds snapshot unchanged.
|
||||
assert.False(t, m.SnapshotChanged([]string{configPath}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDifferentialReload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// These tests verify differential reload behavior: client
|
||||
// reuse for unchanged servers, reconnect for changed ones,
|
||||
// and close for removed ones.
|
||||
|
||||
t.Run("UnchangedServerReusesClient", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Capture the client pointer.
|
||||
m.mu.RLock()
|
||||
origClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
require.NotNil(t, origClient)
|
||||
|
||||
// Add a new server without changing the existing one.
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
cfgMap := map[string]mcpServerEntry{"srv": entry, "srv2": entry2}
|
||||
writeMCPConfig(t, dir, cfgMap)
|
||||
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The unchanged server should reuse the same client.
|
||||
m.mu.RLock()
|
||||
newClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
assert.Same(t, origClient, newClient,
|
||||
"unchanged server should reuse client pointer")
|
||||
|
||||
// Both servers should have tools.
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 2)
|
||||
})
|
||||
|
||||
t.Run("ChangedServerGetsNewClient", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mu.RLock()
|
||||
origClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
|
||||
// Change the server's args to trigger a diff.
|
||||
entry.Args = append(entry.Args, "-test.v")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mu.RLock()
|
||||
newClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
assert.NotSame(t, origClient, newClient,
|
||||
"changed server should get a new client")
|
||||
})
|
||||
|
||||
t.Run("RemovedServerIsClosed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entryA := fakeMCPServerConfig(t, "srvA")
|
||||
_, entryB := fakeMCPServerConfig(t, "srvB")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{
|
||||
"srvA": entryA, "srvB": entryB,
|
||||
})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.Tools(), 2)
|
||||
|
||||
// Capture srvB's client before removal.
|
||||
m.mu.RLock()
|
||||
oldClientB := m.servers["srvB"].client
|
||||
m.mu.RUnlock()
|
||||
require.NotNil(t, oldClientB)
|
||||
|
||||
// Remove srvB from the config.
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{"srvA": entryA})
|
||||
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 1)
|
||||
assert.Contains(t, tools[0].Name, "srvA")
|
||||
|
||||
// The old client for srvB should be closed.
|
||||
// ListTools on a closed client returns an error.
|
||||
listCtx, cancel := context.WithTimeout(ctx, testutil.WaitShort)
|
||||
defer cancel()
|
||||
_, listErr := oldClientB.ListTools(listCtx, mcp.ListToolsRequest{})
|
||||
assert.Error(t, listErr, "ListTools on closed client should fail")
|
||||
})
|
||||
|
||||
t.Run("ConnectFailureRetainsOldClient", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.Tools(), 1)
|
||||
|
||||
m.mu.RLock()
|
||||
origClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
|
||||
// Change config to use a bad command, so connect fails.
|
||||
path := filepath.Join(dir, ".mcp.json")
|
||||
data := `{"mcpServers":{"srv":{"command":"/nonexistent/binary","args":[]}}}`
|
||||
require.NoError(t, os.WriteFile(path, []byte(data), 0o600))
|
||||
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The old client should be retained because the new connect
|
||||
// failed.
|
||||
m.mu.RLock()
|
||||
currentClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
assert.Same(t, origClient, currentClient,
|
||||
"failed connect should retain old client")
|
||||
|
||||
// Tools should still work.
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 1)
|
||||
})
|
||||
|
||||
t.Run("PostReloadToolCallReachesKeptServer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 1)
|
||||
toolName := tools[0].Name
|
||||
|
||||
// Add a second server (srv unchanged, so client is reused).
|
||||
_, entry2 := fakeMCPServerConfig(t, "srv2")
|
||||
writeMCPConfig(t, dir, map[string]mcpServerEntry{
|
||||
"srv": entry, "srv2": entry2,
|
||||
})
|
||||
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
// A tool call to the kept server should reach it.
|
||||
// The client pointer for "srv" was reused, not replaced.
|
||||
_, err = m.CallTool(ctx, workspacesdk.CallMCPToolRequest{
|
||||
ToolName: toolName,
|
||||
})
|
||||
// The fake server does not implement tools/call, so we
|
||||
// expect an error from the server, but the call itself
|
||||
// should reach the server (not ErrUnknownServer).
|
||||
require.Error(t, err, "fake server does not implement tools/call")
|
||||
assert.NotErrorIs(t, err, ErrUnknownServer,
|
||||
"tool call should reach the server, not fail with unknown server")
|
||||
})
|
||||
}
|
||||
|
||||
// TestReload_FirstBootPath verifies that the first-boot call site
|
||||
// (agent.go) can be routed through Reload without behavioral change.
|
||||
func TestReload_FirstBootPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// Simulate first-boot: Reload with the initial config.
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := m.Tools()
|
||||
require.Len(t, tools, 1)
|
||||
assert.Contains(t, tools[0].Name, "echo")
|
||||
}
|
||||
|
||||
// TestReload_NoopWhenUnchanged verifies that Reload returns
|
||||
// immediately without reconnecting when the snapshot is fresh.
|
||||
func TestReload_NoopWhenUnchanged(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mu.RLock()
|
||||
origClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
|
||||
// Second reload with no changes should be a no-op.
|
||||
err = m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mu.RLock()
|
||||
sameClient := m.servers["srv"].client
|
||||
m.mu.RUnlock()
|
||||
|
||||
assert.Same(t, origClient, sameClient,
|
||||
"no-op reload should not replace the client")
|
||||
}
|
||||
|
||||
// TestClose_SuppressesSubprocessExitError verifies that Close
|
||||
// returns nil when servers have running subprocesses that exit
|
||||
// with a kill signal during shutdown.
|
||||
func TestClose_SuppressesSubprocessExitError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if os.Getenv("TEST_MCP_FAKE_SERVER") == "1" {
|
||||
runFakeMCPServer()
|
||||
return
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
dir := t.TempDir()
|
||||
|
||||
_, entry := fakeMCPServerConfig(t, "srv")
|
||||
configPath := writeMCPConfig(t, dir, map[string]mcpServerEntry{"srv": entry})
|
||||
|
||||
m := NewManager(ctx, logger, agentexec.DefaultExecer, nil)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
err := m.Reload(ctx, []string{configPath})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.Tools(), 1, "server should be connected")
|
||||
|
||||
// Close kills the subprocess. The ExitError guard should
|
||||
// suppress the "signal: killed" error.
|
||||
err = m.Close()
|
||||
assert.NoError(t, err, "Close should not propagate subprocess kill errors")
|
||||
}
|
||||
@@ -249,8 +249,9 @@ func (p *Server) loadCachedWorkspaceContext(
|
||||
}
|
||||
|
||||
var tools []fantasy.AgentTool
|
||||
invalidate := func() { p.workspaceMCPToolsCache.Delete(chatID) }
|
||||
for _, t := range entry.tools {
|
||||
tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn))
|
||||
tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn, invalidate))
|
||||
}
|
||||
|
||||
return tools
|
||||
@@ -6290,9 +6291,10 @@ func (p *Server) runChat(
|
||||
}
|
||||
}
|
||||
|
||||
invalidate := func() { p.workspaceMCPToolsCache.Delete(chat.ID) }
|
||||
for _, t := range toolsResp.Tools {
|
||||
workspaceMCPTools = append(workspaceMCPTools,
|
||||
chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn),
|
||||
chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn, invalidate),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -4,10 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
@@ -16,17 +19,22 @@ import (
|
||||
// connection. It implements fantasy.AgentTool so it can be
|
||||
// registered alongside built-in chat tools.
|
||||
type WorkspaceMCPTool struct {
|
||||
info fantasy.ToolInfo
|
||||
getConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
providerOpts fantasy.ProviderOptions
|
||||
info fantasy.ToolInfo
|
||||
getConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
providerOpts fantasy.ProviderOptions
|
||||
invalidateCache func()
|
||||
}
|
||||
|
||||
// NewWorkspaceMCPTool creates a tool wrapper from an MCPToolInfo
|
||||
// discovered on a workspace agent. Each tool proxies calls back
|
||||
// through the agent connection.
|
||||
// through the agent connection. The optional invalidateCache
|
||||
// callback is invoked when CallMCPTool returns a 404 error,
|
||||
// indicating that the server was removed and the chat's cached
|
||||
// tool list should be dropped.
|
||||
func NewWorkspaceMCPTool(
|
||||
tool workspacesdk.MCPToolInfo,
|
||||
getConn func(context.Context) (workspacesdk.AgentConn, error),
|
||||
invalidateCache func(),
|
||||
) *WorkspaceMCPTool {
|
||||
required := tool.Required
|
||||
if required == nil {
|
||||
@@ -40,7 +48,8 @@ func NewWorkspaceMCPTool(
|
||||
Required: required,
|
||||
Parallel: true,
|
||||
},
|
||||
getConn: getConn,
|
||||
getConn: getConn,
|
||||
invalidateCache: invalidateCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +84,15 @@ func (t *WorkspaceMCPTool) Run(
|
||||
Arguments: args,
|
||||
})
|
||||
if err != nil {
|
||||
// If the agent returns a 404 (ErrUnknownServer), the
|
||||
// server was removed or renamed. Invalidate the chat's
|
||||
// cached tool list so the next turn refetches.
|
||||
var coderErr *codersdk.Error
|
||||
if errors.As(err, &coderErr) && coderErr.StatusCode() == http.StatusNotFound {
|
||||
if t.invalidateCache != nil {
|
||||
t.invalidateCache()
|
||||
}
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// fakeAgentConn implements just enough of workspacesdk.AgentConn
|
||||
// for testing CallMCPTool.
|
||||
type fakeAgentConn struct {
|
||||
workspacesdk.AgentConn
|
||||
callMCPToolFunc func(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error)
|
||||
}
|
||||
|
||||
func (f *fakeAgentConn) CallMCPTool(ctx context.Context, req workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
|
||||
return f.callMCPToolFunc(ctx, req)
|
||||
}
|
||||
|
||||
func TestWorkspaceMCPTool_InvalidateOn404(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("404ErrorInvalidatesCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var invalidated atomic.Bool
|
||||
tool := chattool.NewWorkspaceMCPTool(
|
||||
workspacesdk.MCPToolInfo{
|
||||
Name: "test__echo",
|
||||
Description: "test tool",
|
||||
},
|
||||
func(ctx context.Context) (workspacesdk.AgentConn, error) {
|
||||
return &fakeAgentConn{
|
||||
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
|
||||
return workspacesdk.CallMCPToolResponse{}, codersdk.NewError(
|
||||
http.StatusNotFound,
|
||||
codersdk.Response{
|
||||
Message: "MCP tool call failed.",
|
||||
Detail: `unknown MCP server: "test"`,
|
||||
},
|
||||
)
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
func() { invalidated.Store(true) },
|
||||
)
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError, "response should be an error")
|
||||
assert.True(t, invalidated.Load(),
|
||||
"invalidateCache should fire on 404")
|
||||
})
|
||||
|
||||
t.Run("Non404DoesNotInvalidate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var invalidated atomic.Bool
|
||||
tool := chattool.NewWorkspaceMCPTool(
|
||||
workspacesdk.MCPToolInfo{
|
||||
Name: "test__echo",
|
||||
Description: "test tool",
|
||||
},
|
||||
func(ctx context.Context) (workspacesdk.AgentConn, error) {
|
||||
return &fakeAgentConn{
|
||||
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
|
||||
return workspacesdk.CallMCPToolResponse{}, codersdk.NewError(
|
||||
http.StatusBadGateway,
|
||||
codersdk.Response{
|
||||
Message: "Bad Gateway",
|
||||
},
|
||||
)
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
func() { invalidated.Store(true) },
|
||||
)
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.False(t, invalidated.Load(),
|
||||
"invalidateCache should NOT fire on non-404 error")
|
||||
})
|
||||
|
||||
t.Run("ToolLevelErrorNoInvalidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var invalidated atomic.Bool
|
||||
tool := chattool.NewWorkspaceMCPTool(
|
||||
workspacesdk.MCPToolInfo{
|
||||
Name: "test__echo",
|
||||
Description: "test tool",
|
||||
},
|
||||
func(ctx context.Context) (workspacesdk.AgentConn, error) {
|
||||
return &fakeAgentConn{
|
||||
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
|
||||
return workspacesdk.CallMCPToolResponse{
|
||||
IsError: true,
|
||||
Content: []workspacesdk.MCPToolContent{
|
||||
{Type: "text", Text: "tool error"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
func() { invalidated.Store(true) },
|
||||
)
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.False(t, invalidated.Load(),
|
||||
"invalidateCache should NOT fire on tool-level error (HTTP 200)")
|
||||
})
|
||||
|
||||
t.Run("NilInvalidateCallbackSafe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewWorkspaceMCPTool(
|
||||
workspacesdk.MCPToolInfo{
|
||||
Name: "test__echo",
|
||||
Description: "test tool",
|
||||
},
|
||||
func(ctx context.Context) (workspacesdk.AgentConn, error) {
|
||||
return &fakeAgentConn{
|
||||
callMCPToolFunc: func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) {
|
||||
return workspacesdk.CallMCPToolResponse{}, codersdk.NewError(
|
||||
http.StatusNotFound,
|
||||
codersdk.Response{
|
||||
Message: "MCP tool call failed.",
|
||||
Detail: `unknown MCP server: "test"`,
|
||||
},
|
||||
)
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// Should not panic.
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user