mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
e388a88592
## Summary
Adds a new `coderd/chatd/mcpclient` package that connects to
admin-configured MCP servers and wraps their tools as
`fantasy.AgentTool` values that the chat loop can invoke.
## What changed
### New: `coderd/chatd/mcpclient/mcpclient.go`
The core package with a single entry point:
```go
func ConnectAll(
ctx context.Context,
logger slog.Logger,
configs []database.MCPServerConfig,
tokens []database.MCPServerUserToken,
) (tools []fantasy.AgentTool, cleanup func())
```
This:
1. Connects to each enabled MCP server using `mark3labs/mcp-go`
(streamable HTTP or SSE transport)
2. Discovers tools via the MCP `tools/list` method
3. Wraps each tool as a `fantasy.AgentTool` with namespaced name
(`serverslug__toolname`)
4. Applies tool allow/deny list filtering from the server config
5. Handles auth: OAuth2 bearer tokens, API keys, and custom headers
6. Skips broken servers with a warning (10s connect timeout per server)
7. Returns a cleanup function to close all MCP connections
### Modified: `coderd/chatd/chatd.go`
In `runChat()`, after loading the model/messages but before assembling
the tool list:
- Reads `chat.MCPServerIDs` from the chat record
- Loads the MCP server configs from the database
- Resolves the user's auth tokens
- Calls `mcpclient.ConnectAll()` to connect and discover tools
- Appends the MCP tools to the chat's tool set
- Defers cleanup to close connections when the chat turn ends
The chat loop (`chatloop.Run`) already handles tools generically —
MCP-backed tools are invoked identically to built-in workspace tools. No
changes needed in `chatloop/`.
### New: `coderd/chatd/mcpclient/mcpclient_test.go`
10 tests covering:
- Tool discovery and namespacing
- Tool call forwarding and result conversion
- Allow/deny list filtering
- Connection failure handling (graceful skip)
- Multi-server support with correct prefixes
- OAuth2 auth header injection
- Disabled server skipping
- Invalid input handling
- Tool info parameter propagation
## Design decisions
- **Tool namespacing**: Tools are exposed as `slug__toolname`, using a
double-underscore separator so that tool names containing single underscores
do not collide with it. The prefix is stripped when forwarding to `tools/call`.
- **Connection lifecycle**: Fresh connections per chat turn, closed via
`defer`. Matches the `turnWorkspaceContext` pattern.
- **Failure isolation**: Each server connects independently. A broken
server doesn't fail the chat — its tools are simply unavailable.
- **No chatloop changes**: The existing `[]fantasy.AgentTool` interface
is already fully generic.
## What's NOT in this PR (follow-ups)
- Frontend MCP server picker UI (selecting servers for a chat)
- System prompt additions describing available MCP tools
- Token refresh on expiry mid-chat
- The deprecated `aibridged` MCP proxy cleanup
543 lines
13 KiB
Go
543 lines
13 KiB
Go
package mcpclient
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/client/transport"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/buildinfo"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
)
|
|
|
|
const (
	// toolNameSep joins a server slug and a tool's original name to
	// form the prefixed tool name. A double underscore is used so
	// that tool names containing single underscores do not collide
	// with the separator.
	//
	// TODO: names that themselves contain "__" make the prefixed
	// form ambiguous: "srv__my__tool" could be slug "srv" + tool
	// "my__tool" or slug "srv__my" + tool "tool". Invocation is not
	// affected, because originalName is used directly when calling
	// the remote server.
	toolNameSep = "__"

	// connectTimeout is the budget for a single MCP server to start
	// its transport and finish initialization. A server that exceeds
	// it is skipped, so one slow server cannot hold up the whole
	// chat startup.
	connectTimeout = 10 * time.Second

	// toolCallTimeout is the maximum duration of one tool invocation
	// before it is canceled.
	toolCallTimeout = 60 * time.Second
)
|
|
|
|
// ConnectAll connects to all configured MCP servers, discovers
|
|
// their tools, and returns them as fantasy.AgentTool values. It
|
|
// skips servers that fail to connect and logs warnings. The
|
|
// returned cleanup function must be called to close all
|
|
// connections.
|
|
func ConnectAll(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
configs []database.MCPServerConfig,
|
|
tokens []database.MCPServerUserToken,
|
|
) ([]fantasy.AgentTool, func()) {
|
|
// Index tokens by server config ID so auth header
|
|
// construction is O(1) per server.
|
|
tokensByConfigID := make(
|
|
map[uuid.UUID]database.MCPServerUserToken, len(tokens),
|
|
)
|
|
for _, tok := range tokens {
|
|
tokensByConfigID[tok.MCPServerConfigID] = tok
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
clients []*client.Client
|
|
tools []fantasy.AgentTool
|
|
)
|
|
|
|
// Build cleanup eagerly so it always closes any clients
|
|
// that connected, even if a later connection fails.
|
|
cleanup := func() {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
for _, c := range clients {
|
|
_ = c.Close()
|
|
}
|
|
clients = nil
|
|
}
|
|
|
|
var eg errgroup.Group
|
|
for _, cfg := range configs {
|
|
if !cfg.Enabled {
|
|
continue
|
|
}
|
|
|
|
eg.Go(func() error {
|
|
serverTools, mcpClient, connectErr := connectOne(
|
|
ctx, logger, cfg, tokensByConfigID,
|
|
)
|
|
if connectErr != nil {
|
|
logger.Warn(ctx,
|
|
"skipping MCP server due to connection failure",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("server_url", RedactURL(cfg.Url)),
|
|
slog.F("error", redactErrorURL(connectErr)),
|
|
)
|
|
// Connection failures are not propagated — the
|
|
// LLM simply won't have this server's tools.
|
|
return nil
|
|
}
|
|
|
|
mu.Lock()
|
|
clients = append(clients, mcpClient)
|
|
tools = append(tools, serverTools...)
|
|
mu.Unlock()
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// All goroutines return nil; error is intentionally
|
|
// discarded.
|
|
_ = eg.Wait()
|
|
|
|
return tools, cleanup
|
|
}
|
|
|
|
// connectOne establishes a connection to a single MCP server,
|
|
// discovers its tools, and wraps each one as an AgentTool with
|
|
// the server slug prefix applied.
|
|
func connectOne(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
cfg database.MCPServerConfig,
|
|
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
|
) ([]fantasy.AgentTool, *client.Client, error) {
|
|
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID)
|
|
|
|
tr, err := createTransport(cfg, headers)
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf(
|
|
"create transport: %w", err,
|
|
)
|
|
}
|
|
|
|
mcpClient := client.NewClient(tr)
|
|
|
|
// The timeout covers the entire connect+init+list sequence,
|
|
// not each phase individually.
|
|
connectCtx, cancel := context.WithTimeout(
|
|
ctx, connectTimeout,
|
|
)
|
|
defer cancel()
|
|
|
|
if err := mcpClient.Start(connectCtx); err != nil {
|
|
_ = mcpClient.Close()
|
|
return nil, nil, xerrors.Errorf(
|
|
"start transport: %w", err,
|
|
)
|
|
}
|
|
|
|
_, err = mcpClient.Initialize(
|
|
connectCtx,
|
|
mcp.InitializeRequest{
|
|
Params: mcp.InitializeParams{
|
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
|
ClientInfo: mcp.Implementation{
|
|
Name: "coder",
|
|
Version: buildinfo.Version(),
|
|
},
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
// Best-effort close so we don't leak the transport.
|
|
_ = mcpClient.Close()
|
|
return nil, nil, xerrors.Errorf("initialize: %w", err)
|
|
}
|
|
|
|
toolsResult, err := mcpClient.ListTools(
|
|
connectCtx, mcp.ListToolsRequest{},
|
|
)
|
|
if err != nil {
|
|
_ = mcpClient.Close()
|
|
return nil, nil, xerrors.Errorf("list tools: %w", err)
|
|
}
|
|
|
|
var tools []fantasy.AgentTool
|
|
for _, mcpTool := range toolsResult.Tools {
|
|
if !isToolAllowed(
|
|
mcpTool.Name,
|
|
cfg.ToolAllowList,
|
|
cfg.ToolDenyList,
|
|
) {
|
|
logger.Debug(ctx, "skipping denied MCP tool",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("tool_name", mcpTool.Name),
|
|
)
|
|
continue
|
|
}
|
|
|
|
tools = append(
|
|
tools, newMCPTool(cfg.Slug, mcpTool, mcpClient),
|
|
)
|
|
}
|
|
|
|
// If no tools passed filtering, close the client early
|
|
// to avoid holding an idle connection.
|
|
if len(tools) == 0 {
|
|
_ = mcpClient.Close()
|
|
return nil, nil, nil
|
|
}
|
|
|
|
return tools, mcpClient, nil
|
|
}
|
|
|
|
// createTransport builds the appropriate mcp-go transport based
|
|
// on the server's configured transport type.
|
|
func createTransport(
|
|
cfg database.MCPServerConfig,
|
|
headers map[string]string,
|
|
) (transport.Interface, error) {
|
|
switch cfg.Transport {
|
|
case "sse":
|
|
return transport.NewSSE(
|
|
cfg.Url,
|
|
transport.WithHeaders(headers),
|
|
)
|
|
case "", "streamable_http":
|
|
// Default to streamable HTTP, the newer transport.
|
|
return transport.NewStreamableHTTP(
|
|
cfg.Url,
|
|
transport.WithHTTPHeaders(headers),
|
|
)
|
|
default:
|
|
return nil, xerrors.Errorf(
|
|
"unsupported transport %q", cfg.Transport,
|
|
)
|
|
}
|
|
}
|
|
|
|
// buildAuthHeaders constructs HTTP headers for authenticating
|
|
// with the MCP server based on the configured auth type.
|
|
func buildAuthHeaders(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
cfg database.MCPServerConfig,
|
|
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
|
) map[string]string {
|
|
// Using map[string]string rather than http.Header because
|
|
// the mcp-go transport options accept map[string]string.
|
|
// MCP servers typically don't require multi-valued headers.
|
|
headers := make(map[string]string)
|
|
|
|
switch cfg.AuthType {
|
|
case "oauth2":
|
|
tok, ok := tokensByConfigID[cfg.ID]
|
|
if !ok {
|
|
logger.Warn(ctx,
|
|
"no oauth2 token found for MCP server",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
if tok.Expiry.Valid && tok.Expiry.Time.Before(time.Now()) {
|
|
logger.Warn(ctx,
|
|
"oauth2 token for MCP server is expired",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("expired_at", tok.Expiry.Time),
|
|
)
|
|
}
|
|
if tok.AccessToken == "" {
|
|
logger.Warn(ctx,
|
|
"oauth2 token record has empty access token",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
tokenType := tok.TokenType
|
|
if tokenType == "" {
|
|
tokenType = "Bearer"
|
|
}
|
|
headers["Authorization"] = tokenType + " " + tok.AccessToken
|
|
case "api_key":
|
|
if cfg.APIKeyHeader != "" && cfg.APIKeyValue != "" {
|
|
headers[cfg.APIKeyHeader] = cfg.APIKeyValue
|
|
}
|
|
case "custom_headers":
|
|
if cfg.CustomHeaders != "" {
|
|
var custom map[string]string
|
|
if err := json.Unmarshal(
|
|
[]byte(cfg.CustomHeaders), &custom,
|
|
); err != nil {
|
|
logger.Warn(ctx,
|
|
"failed to parse custom headers JSON",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.Error(err),
|
|
)
|
|
} else {
|
|
for k, v := range custom {
|
|
headers[k] = v
|
|
}
|
|
}
|
|
}
|
|
case "none", "":
|
|
// No auth headers needed.
|
|
}
|
|
|
|
return headers
|
|
}
|
|
|
|
// isToolAllowed reports whether toolName may be exposed, given the
// server's allow and deny lists. A non-empty allow list is
// authoritative: only names in it are permitted and the deny list
// is not consulted. With an empty allow list, every name is
// permitted except those in the deny list. Matching is exact and is
// done against the original (non-prefixed) tool name.
func isToolAllowed(
	toolName string,
	allowList []string,
	denyList []string,
) bool {
	// Pick the one list that decides the outcome, and what a hit in
	// that list means.
	deciding, verdictOnHit := denyList, false
	if len(allowList) > 0 {
		deciding, verdictOnHit = allowList, true
	}

	for _, name := range deciding {
		if name == toolName {
			return verdictOnHit
		}
	}

	// No hit: an allow list rejects, a deny list (or no list) permits.
	return !verdictOnHit
}
|
|
|
|
// RedactURL returns rawURL with its userinfo, query string and
// fragment removed so that embedded credentials are not logged.
// The query is dropped because API keys are sometimes passed as
// ?api_key=sk-... in server URLs. Input that does not parse as a
// URL is returned unchanged.
func RedactURL(rawURL string) string {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return rawURL
	}

	parsed.User, parsed.RawQuery, parsed.Fragment = nil, "", ""
	return parsed.String()
}
|
|
|
|
// redactErrorURL rewrites URLs in an error string to strip
|
|
// credentials. Go's net/http embeds the full request URL in
|
|
// *url.Error messages, which can leak userinfo.
|
|
func redactErrorURL(err error) string {
|
|
if err == nil {
|
|
return ""
|
|
}
|
|
var urlErr *url.Error
|
|
if errors.As(err, &urlErr) {
|
|
urlErr.URL = RedactURL(urlErr.URL)
|
|
return urlErr.Error()
|
|
}
|
|
return err.Error()
|
|
}
|
|
|
|
// mcpToolWrapper adapts a single MCP tool into a
|
|
// fantasy.AgentTool. It stores the prefixed name for Info() but
|
|
// strips the prefix when forwarding calls to the remote server.
|
|
type mcpToolWrapper struct {
|
|
prefixedName string
|
|
originalName string
|
|
description string
|
|
parameters map[string]any
|
|
required []string
|
|
client *client.Client
|
|
providerOptions fantasy.ProviderOptions
|
|
}
|
|
|
|
// newMCPTool creates an mcpToolWrapper from an mcp.Tool
|
|
// discovered on a remote server.
|
|
func newMCPTool(
|
|
serverSlug string,
|
|
tool mcp.Tool,
|
|
mcpClient *client.Client,
|
|
) *mcpToolWrapper {
|
|
return &mcpToolWrapper{
|
|
prefixedName: serverSlug + toolNameSep + tool.Name,
|
|
originalName: tool.Name,
|
|
description: tool.Description,
|
|
parameters: tool.InputSchema.Properties,
|
|
required: tool.InputSchema.Required,
|
|
client: mcpClient,
|
|
}
|
|
}
|
|
|
|
func (t *mcpToolWrapper) Info() fantasy.ToolInfo {
|
|
return fantasy.ToolInfo{
|
|
Name: t.prefixedName,
|
|
Description: t.description,
|
|
Parameters: t.parameters,
|
|
Required: t.required,
|
|
Parallel: true,
|
|
}
|
|
}
|
|
|
|
func (t *mcpToolWrapper) Run(
|
|
ctx context.Context,
|
|
params fantasy.ToolCall,
|
|
) (fantasy.ToolResponse, error) {
|
|
var args map[string]any
|
|
if params.Input != "" {
|
|
if err := json.Unmarshal(
|
|
[]byte(params.Input), &args,
|
|
); err != nil {
|
|
return fantasy.NewTextErrorResponse(
|
|
"invalid JSON input: " + err.Error(),
|
|
), nil
|
|
}
|
|
}
|
|
|
|
callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
|
|
defer cancel()
|
|
|
|
result, err := t.client.CallTool(
|
|
callCtx,
|
|
mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Name: t.originalName,
|
|
Arguments: args,
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return fantasy.NewTextErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
return convertCallResult(result), nil
|
|
}
|
|
|
|
func (t *mcpToolWrapper) ProviderOptions() fantasy.ProviderOptions {
|
|
return t.providerOptions
|
|
}
|
|
|
|
func (t *mcpToolWrapper) SetProviderOptions(
|
|
opts fantasy.ProviderOptions,
|
|
) {
|
|
t.providerOptions = opts
|
|
}
|
|
|
|
// convertCallResult translates an MCP CallToolResult into a
|
|
// fantasy.ToolResponse. The fantasy response model supports a
|
|
// single content type per response, so we prioritize text. All
|
|
// text items are collected first. Binary items (image or audio)
|
|
// are only returned when no text content is available.
|
|
func convertCallResult(
|
|
result *mcp.CallToolResult,
|
|
) fantasy.ToolResponse {
|
|
if result == nil {
|
|
return fantasy.NewTextResponse("")
|
|
}
|
|
|
|
var (
|
|
textParts []string
|
|
binaryResult *fantasy.ToolResponse
|
|
)
|
|
for _, item := range result.Content {
|
|
switch c := item.(type) {
|
|
case mcp.TextContent:
|
|
textParts = append(textParts, c.Text)
|
|
case mcp.ImageContent:
|
|
data, err := base64.StdEncoding.DecodeString(
|
|
c.Data,
|
|
)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[image decode error: "+err.Error()+"]",
|
|
)
|
|
continue
|
|
}
|
|
if binaryResult == nil {
|
|
r := fantasy.ToolResponse{
|
|
Type: "image",
|
|
Data: data,
|
|
MediaType: c.MIMEType,
|
|
IsError: result.IsError,
|
|
}
|
|
binaryResult = &r
|
|
}
|
|
case mcp.AudioContent:
|
|
data, err := base64.StdEncoding.DecodeString(
|
|
c.Data,
|
|
)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[audio decode error: "+err.Error()+"]",
|
|
)
|
|
continue
|
|
}
|
|
if binaryResult == nil {
|
|
r := fantasy.ToolResponse{
|
|
Type: "media",
|
|
Data: data,
|
|
MediaType: c.MIMEType,
|
|
IsError: result.IsError,
|
|
}
|
|
binaryResult = &r
|
|
}
|
|
default:
|
|
textParts = append(textParts,
|
|
fmt.Sprintf("[unsupported content type: %T]", c),
|
|
)
|
|
}
|
|
}
|
|
|
|
// If structured content is present, marshal it to JSON and
|
|
// append as a text part so the data is preserved for the LLM.
|
|
if result.StructuredContent != nil {
|
|
data, err := json.Marshal(result.StructuredContent)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[structured content marshal error: "+
|
|
err.Error()+"]",
|
|
)
|
|
} else {
|
|
textParts = append(textParts, string(data))
|
|
}
|
|
}
|
|
|
|
// Prefer text content. Only fall back to binary when no
|
|
// text was collected.
|
|
if len(textParts) > 0 {
|
|
resp := fantasy.NewTextResponse(
|
|
strings.Join(textParts, "\n"),
|
|
)
|
|
resp.IsError = result.IsError
|
|
return resp
|
|
}
|
|
if binaryResult != nil {
|
|
return *binaryResult
|
|
}
|
|
return fantasy.NewTextResponse("")
|
|
}
|