mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
d2f9ad783e
Threads the per-user custom_headers values stored on mcp_server_user_header_values through the chatd MCP client so users who provided a value for an admin-marked CustomHeadersUserKey see it mixed into the outgoing request alongside the admin-static headers. Changes: - mcpclient.ConnectAll grows a fourth indexed input, []database.McpServerUserHeaderValue, which buildAuthHeaders consults inside the custom_headers branch to overlay per-user values on top of admin static headers, scoped to cfg.CustomHeadersUserKeys. - chatd loads the user's stored header values via GetMCPServerUserHeaderValuesByUserID alongside the existing GetMCPServerUserTokensByUserID call and threads them into ConnectAll. A missing row is non-fatal: admin headers still ship, user-keyed headers are simply absent and a warning is logged. - mcpclient.go inlines its own DefaultTransport clone for test isolation, replacing the standalone helper in mcphttpclient.go, which is removed. Stack: 4/6 (chatd runtime overlay)
955 lines
27 KiB
Go
955 lines
27 KiB
Go
package mcpclient
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/client/transport"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/buildinfo"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
)
|
|
|
|
const (
	// toolNameSep separates the server slug from the original tool
	// name in prefixed tool names ("<slug>__<tool>"). A double
	// underscore avoids collisions with tool names that contain a
	// single underscore.
	//
	// TODO: tool names that themselves contain "__" produce ambiguous
	// prefixed names: "srv__my__tool" could be slug "srv" + tool
	// "my__tool" or slug "srv__my" + tool "tool". Invocation is not
	// affected because the remote call uses originalName directly.
	toolNameSep = "__"

	// connectTimeout bounds how long a single MCP server may take to
	// start its transport, initialize, and list its tools. Servers
	// that take longer are skipped so one slow server cannot block
	// the entire chat startup. It is also reused as the timeout for
	// OAuth2 token refresh requests.
	connectTimeout = 10 * time.Second

	// toolCallTimeout bounds a single tool invocation; the call is
	// canceled once it elapses.
	toolCallTimeout = 60 * time.Second
)
|
|
|
|
// UserOIDCTokenSource resolves the OIDC access token for the calling
// user. buildAuthHeaders consults it for servers whose auth type is
// "user_oidc" and sends the result as "Authorization: Bearer <token>".
//
// Implementations should attempt to refresh tokens that are expired
// or close to expiring, and MUST return ("", nil) when the user has
// no OIDC link or a refresh attempt failed for any reason. A non-nil
// error is reserved for unexpected infrastructure failures (e.g.
// database errors); the caller logs it and sends no Authorization
// header for that server. The empty-token-on-refresh-failure behavior
// matches provisionerdserver.ObtainOIDCAccessToken.
type UserOIDCTokenSource interface {
	// OIDCAccessToken returns the user's current OIDC access token,
	// or "" when none is available.
	OIDCAccessToken(ctx context.Context, userID uuid.UUID) (string, error)
}
|
|
|
|
// ConnectAll connects to all configured MCP servers, discovers
|
|
// their tools, and returns them as fantasy.AgentTool values.
|
|
// Tools are sorted by their prefixed name so callers
|
|
// receive a deterministic order. It skips servers that fail to
|
|
// connect and logs warnings. The returned cleanup function
|
|
// must be called to close all connections.
|
|
func ConnectAll(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
configs []database.MCPServerConfig,
|
|
tokens []database.MCPServerUserToken,
|
|
userHeaderValues []database.McpServerUserHeaderValue,
|
|
userID uuid.UUID,
|
|
oidcSrc UserOIDCTokenSource,
|
|
coderHeaders map[string]string,
|
|
) ([]fantasy.AgentTool, func()) {
|
|
// Index tokens by server config ID so auth header
|
|
// construction is O(1) per server.
|
|
tokensByConfigID := make(
|
|
map[uuid.UUID]database.MCPServerUserToken, len(tokens),
|
|
)
|
|
for _, tok := range tokens {
|
|
tokensByConfigID[tok.MCPServerConfigID] = tok
|
|
}
|
|
|
|
// Same indexing for the calling user's custom_headers values.
|
|
userHeaderValuesByConfigID := make(
|
|
map[uuid.UUID]database.McpServerUserHeaderValue, len(userHeaderValues),
|
|
)
|
|
for _, hv := range userHeaderValues {
|
|
userHeaderValuesByConfigID[hv.MCPServerConfigID] = hv
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
clients []*client.Client
|
|
tools []fantasy.AgentTool
|
|
)
|
|
|
|
// Build cleanup eagerly so it always closes any clients
|
|
// that connected, even if a later connection fails.
|
|
cleanup := func() {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
for _, c := range clients {
|
|
_ = c.Close()
|
|
}
|
|
clients = nil
|
|
}
|
|
|
|
var eg errgroup.Group
|
|
for _, cfg := range configs {
|
|
if !cfg.Enabled {
|
|
continue
|
|
}
|
|
|
|
eg.Go(func() error {
|
|
serverTools, mcpClient, connectErr := connectOne(
|
|
ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc, coderHeaders,
|
|
)
|
|
if connectErr != nil {
|
|
logger.Warn(ctx,
|
|
"skipping MCP server due to connection failure",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("server_url", RedactURL(cfg.Url)),
|
|
slog.F("error", redactErrorURL(connectErr)),
|
|
)
|
|
// Connection failures are not propagated — the
|
|
// LLM simply won't have this server's tools.
|
|
return nil
|
|
}
|
|
|
|
mu.Lock()
|
|
if mcpClient != nil {
|
|
clients = append(clients, mcpClient)
|
|
}
|
|
tools = append(tools, serverTools...)
|
|
mu.Unlock()
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// All goroutines return nil; error is intentionally
|
|
// discarded.
|
|
_ = eg.Wait()
|
|
|
|
// Sort tools by prefixed name for deterministic ordering
|
|
// regardless of goroutine completion order. Ties, possible
|
|
// when the __ separator produces ambiguous prefixed names,
|
|
// are broken by config ID. Stable prompt construction
|
|
// depends on consistent tool ordering.
|
|
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
|
|
// All tools in this slice are mcpToolWrapper values
|
|
// created by connectOne above, so these checked
|
|
// assertions should always succeed. The config ID
|
|
// tiebreaker resolves the __ separator ambiguity
|
|
// documented at the top of this file.
|
|
aTool, ok := a.(MCPToolIdentifier)
|
|
if !ok {
|
|
panic(fmt.Sprintf("unexpected tool type %T", a))
|
|
}
|
|
bTool, ok := b.(MCPToolIdentifier)
|
|
if !ok {
|
|
panic(fmt.Sprintf("unexpected tool type %T", b))
|
|
}
|
|
return cmp.Or(
|
|
cmp.Compare(a.Info().Name, b.Info().Name),
|
|
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
|
|
)
|
|
})
|
|
|
|
return tools, cleanup
|
|
}
|
|
|
|
// connectOne establishes a connection to a single MCP server,
|
|
// discovers its tools, and wraps each one as an AgentTool with
|
|
// the server slug prefix applied.
|
|
func connectOne(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
cfg database.MCPServerConfig,
|
|
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
|
userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
|
|
userID uuid.UUID,
|
|
oidcSrc UserOIDCTokenSource,
|
|
coderHeaders map[string]string,
|
|
) ([]fantasy.AgentTool, *client.Client, error) {
|
|
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc)
|
|
|
|
// When opted-in, merge Coder identity headers BEFORE the
|
|
// transport is created so any auth header already set above
|
|
// wins on a conflict. Conflict detection uses
|
|
// http.CanonicalHeaderKey because the upstream transport applies
|
|
// http.Header.Set, which canonicalizes keys; without that, an
|
|
// admin-configured header that differs only in case from a Coder
|
|
// identity header would land in the request map twice and the
|
|
// surviving value would be non-deterministic.
|
|
if cfg.ForwardCoderHeaders {
|
|
canonicalAuth := make(map[string]struct{}, len(headers))
|
|
for k := range headers {
|
|
canonicalAuth[http.CanonicalHeaderKey(k)] = struct{}{}
|
|
}
|
|
for k, v := range coderHeaders {
|
|
if _, exists := canonicalAuth[http.CanonicalHeaderKey(k)]; exists {
|
|
continue
|
|
}
|
|
headers[k] = v
|
|
}
|
|
}
|
|
|
|
tr, err := createTransport(cfg, headers)
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf(
|
|
"create transport: %w", err,
|
|
)
|
|
}
|
|
|
|
mcpClient := client.NewClient(tr)
|
|
|
|
// The timeout covers the entire connect+init+list sequence,
|
|
// not each phase individually.
|
|
connectCtx, cancel := context.WithTimeout(
|
|
ctx, connectTimeout,
|
|
)
|
|
defer cancel()
|
|
|
|
if err := mcpClient.Start(connectCtx); err != nil {
|
|
_ = mcpClient.Close()
|
|
return nil, nil, xerrors.Errorf(
|
|
"start transport: %w", err,
|
|
)
|
|
}
|
|
|
|
_, err = mcpClient.Initialize(
|
|
connectCtx,
|
|
mcp.InitializeRequest{
|
|
Params: mcp.InitializeParams{
|
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
|
ClientInfo: mcp.Implementation{
|
|
Name: "coder",
|
|
Version: buildinfo.Version(),
|
|
},
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
// Best-effort close so we don't leak the transport.
|
|
_ = mcpClient.Close()
|
|
return nil, nil, xerrors.Errorf("initialize: %w", err)
|
|
}
|
|
|
|
toolsResult, err := mcpClient.ListTools(
|
|
connectCtx, mcp.ListToolsRequest{},
|
|
)
|
|
if err != nil {
|
|
_ = mcpClient.Close()
|
|
return nil, nil, xerrors.Errorf("list tools: %w", err)
|
|
}
|
|
|
|
var tools []fantasy.AgentTool
|
|
for _, mcpTool := range toolsResult.Tools {
|
|
if !isToolAllowed(
|
|
mcpTool.Name,
|
|
cfg.ToolAllowList,
|
|
cfg.ToolDenyList,
|
|
) {
|
|
logger.Debug(ctx, "skipping denied MCP tool",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("tool_name", mcpTool.Name),
|
|
)
|
|
continue
|
|
}
|
|
|
|
tools = append(
|
|
tools, newMCPTool(cfg.ID, cfg.Slug, mcpTool, mcpClient, cfg.ModelIntent),
|
|
)
|
|
}
|
|
|
|
// If no tools passed filtering, close the client early
|
|
// to avoid holding an idle connection.
|
|
if len(tools) == 0 {
|
|
_ = mcpClient.Close()
|
|
return nil, nil, nil
|
|
}
|
|
|
|
return tools, mcpClient, nil
|
|
}
|
|
|
|
// createTransport builds the appropriate mcp-go transport based
|
|
// on the server's configured transport type.
|
|
func createTransport(
|
|
cfg database.MCPServerConfig,
|
|
headers map[string]string,
|
|
) (transport.Interface, error) {
|
|
// Each connection gets its own HTTP client with a dedicated
|
|
// transport so that httptest.Server.Close() (which calls
|
|
// CloseIdleConnections on http.DefaultTransport) does not
|
|
// disrupt unrelated connections during parallel tests.
|
|
var httpClient *http.Client
|
|
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
|
httpClient = &http.Client{Transport: dt.Clone()}
|
|
} else {
|
|
httpClient = &http.Client{}
|
|
}
|
|
|
|
switch cfg.Transport {
|
|
case "sse":
|
|
return transport.NewSSE(
|
|
cfg.Url,
|
|
transport.WithHeaders(headers),
|
|
transport.WithHTTPClient(httpClient),
|
|
)
|
|
case "", "streamable_http":
|
|
// Default to streamable HTTP, the newer transport.
|
|
return transport.NewStreamableHTTP(
|
|
cfg.Url,
|
|
transport.WithHTTPHeaders(headers),
|
|
transport.WithHTTPBasicClient(httpClient),
|
|
)
|
|
default:
|
|
return nil, xerrors.Errorf(
|
|
"unsupported transport %q", cfg.Transport,
|
|
)
|
|
}
|
|
}
|
|
|
|
// buildAuthHeaders constructs HTTP headers for authenticating
|
|
// with the MCP server based on the configured auth type.
|
|
func buildAuthHeaders(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
cfg database.MCPServerConfig,
|
|
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
|
userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
|
|
userID uuid.UUID,
|
|
oidcSrc UserOIDCTokenSource,
|
|
) map[string]string {
|
|
// Using map[string]string rather than http.Header because
|
|
// the mcp-go transport options accept map[string]string.
|
|
// MCP servers typically don't require multi-valued headers.
|
|
headers := make(map[string]string)
|
|
|
|
switch cfg.AuthType {
|
|
case "oauth2":
|
|
tok, ok := tokensByConfigID[cfg.ID]
|
|
if !ok {
|
|
logger.Warn(ctx,
|
|
"no oauth2 token found for MCP server",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
if tok.Expiry.Valid && tok.Expiry.Time.Before(time.Now()) {
|
|
logger.Warn(ctx,
|
|
"oauth2 token for MCP server is expired",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("expired_at", tok.Expiry.Time),
|
|
)
|
|
}
|
|
if tok.AccessToken == "" {
|
|
logger.Warn(ctx,
|
|
"oauth2 token record has empty access token",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
tokenType := tok.TokenType
|
|
if tokenType == "" {
|
|
tokenType = "Bearer"
|
|
}
|
|
// RFC 6750 says the scheme is case-insensitive, but
|
|
// some servers (e.g. Linear) reject lowercase
|
|
// "bearer". Normalize to the canonical form.
|
|
if strings.EqualFold(tokenType, "bearer") {
|
|
tokenType = "Bearer"
|
|
}
|
|
headers["Authorization"] = tokenType + " " + tok.AccessToken
|
|
case "api_key":
|
|
if cfg.APIKeyHeader != "" && cfg.APIKeyValue != "" {
|
|
headers[cfg.APIKeyHeader] = cfg.APIKeyValue
|
|
}
|
|
case "custom_headers":
|
|
if cfg.CustomHeaders != "" {
|
|
var custom map[string]string
|
|
if err := json.Unmarshal(
|
|
[]byte(cfg.CustomHeaders), &custom,
|
|
); err != nil {
|
|
logger.Warn(ctx,
|
|
"failed to parse custom headers JSON",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.Error(err),
|
|
)
|
|
} else {
|
|
for k, v := range custom {
|
|
headers[k] = v
|
|
}
|
|
}
|
|
}
|
|
// Overlay user-supplied values for keys the admin marked as
|
|
// user-set. Validation guarantees these are disjoint from
|
|
// CustomHeaders, but the overlay is well-defined either way.
|
|
if len(cfg.CustomHeadersUserKeys) > 0 {
|
|
row, ok := userHeaderValuesByConfigID[cfg.ID]
|
|
if !ok {
|
|
// Normal state: this user has never saved values for
|
|
// this server. The MCP call will proceed without the
|
|
// user-set headers and likely fail at the remote end,
|
|
// which is the expected signal for the UI to prompt
|
|
// the user. Debug-level keeps this off the noise floor.
|
|
logger.Debug(ctx,
|
|
"no user header values for MCP server",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
var user map[string]string
|
|
if err := json.Unmarshal(
|
|
[]byte(row.HeaderValues), &user,
|
|
); err != nil {
|
|
logger.Warn(ctx,
|
|
"failed to parse user header values JSON",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.Error(err),
|
|
)
|
|
break
|
|
}
|
|
for _, k := range cfg.CustomHeadersUserKeys {
|
|
// Case-insensitive lookup so a case-only admin rename
|
|
// does not silently drop the user's stored value.
|
|
v, has := mcpHeaderValueForKey(user, k)
|
|
if has && v != "" {
|
|
headers[k] = v
|
|
}
|
|
}
|
|
}
|
|
case "user_oidc":
|
|
// Forward the calling user's OIDC access token from
|
|
// user_links as Authorization: Bearer <token>. The token
|
|
// source is responsible for refreshing tokens that are
|
|
// expired or close to expiring before returning them.
|
|
if oidcSrc == nil || userID == uuid.Nil {
|
|
logger.Warn(ctx,
|
|
"user_oidc auth requested but no token source available",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
token, err := oidcSrc.OIDCAccessToken(ctx, userID)
|
|
if err != nil {
|
|
logger.Warn(ctx,
|
|
"failed to obtain user OIDC token for MCP server",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.Error(err),
|
|
)
|
|
break
|
|
}
|
|
if token == "" {
|
|
// The user has no OIDC link, or a non-fatal refresh
|
|
// failure occurred. Fall through with no header and let
|
|
// the upstream MCP server decide how to respond
|
|
// (typically 401). Logged at debug so password and
|
|
// GitHub users don't generate noise for every chat turn.
|
|
logger.Debug(ctx,
|
|
"no user OIDC token available for MCP server",
|
|
slog.F("server_slug", cfg.Slug),
|
|
)
|
|
break
|
|
}
|
|
headers["Authorization"] = "Bearer " + token
|
|
case "none", "":
|
|
// No auth headers needed.
|
|
}
|
|
|
|
return headers
|
|
}
|
|
|
|
// mcpHeaderValueForKey looks key up in stored, first by exact match
// and then case-insensitively, reporting whether any entry matched.
// The stored user-header map keeps the admin's casing from when it
// was written, so a later case-only rename of a user-set key would
// otherwise orphan the stored value until the user re-saves it.
func mcpHeaderValueForKey(stored map[string]string, key string) (string, bool) {
	// Fast path: the casing hasn't changed.
	value, found := stored[key]
	if found {
		return value, true
	}
	for storedKey, storedValue := range stored {
		if !strings.EqualFold(storedKey, key) {
			continue
		}
		return storedValue, true
	}
	return "", false
}
|
|
|
|
// isToolAllowed reports whether toolName may be exposed to the model.
// A non-empty allow list is authoritative: only tools named in it are
// permitted and the deny list is ignored. Otherwise every tool is
// permitted except those named in the deny list. Matching is exact
// and uses the original (non-prefixed) tool name.
func isToolAllowed(
	toolName string,
	allowList []string,
	denyList []string,
) bool {
	if len(allowList) != 0 {
		return slices.Contains(allowList, toolName)
	}
	return !slices.Contains(denyList, toolName)
}
|
|
|
|
// RedactURL strips userinfo, query parameters, and the fragment from
// a URL to avoid logging embedded credentials. Query params are
// removed because API keys are sometimes passed as ?api_key=sk-... in
// server URLs.
//
// If url.Parse rejects the input (bad percent-escapes, invalid ports,
// control characters, ...), a best-effort textual redaction is applied
// instead of returning the input verbatim. Malformed URLs can still
// carry credentials. The fallback works as follows:
//   - Anything between "://" and the last "@" before the first "/"
//     is dropped.
//   - Everything from the first "?" or "#" onward is removed.
//
// Malformed input without credentials or a query is returned
// unchanged.
func RedactURL(rawURL string) string {
	u, err := url.Parse(rawURL)
	if err != nil {
		redacted := rawURL
		if scheme, rest, ok := strings.Cut(rawURL, "://"); ok {
			authority, path, hasPath := strings.Cut(rest, "/")
			// LastIndex so a password containing '@' is fully dropped.
			if at := strings.LastIndex(authority, "@"); at >= 0 {
				authority = authority[at+1:]
			}
			redacted = scheme + "://" + authority
			if hasPath {
				redacted += "/" + path
			}
		}
		if i := strings.IndexAny(redacted, "?#"); i >= 0 {
			redacted = redacted[:i]
		}
		return redacted
	}
	u.User = nil
	u.RawQuery = ""
	u.Fragment = ""
	return u.String()
}

// redactErrorURL returns err's message with credentials stripped from
// any URL carried by a wrapped *url.Error. Go's net/http embeds the
// full request URL in *url.Error messages, which can leak userinfo or
// query-string API keys.
//
// The full message, including context added by wrapping errors, is
// preserved; only the URL text is rewritten. err itself is not
// modified. A nil err yields "".
func redactErrorURL(err error) string {
	if err == nil {
		return ""
	}
	msg := err.Error()

	var urlErr *url.Error
	if !errors.As(err, &urlErr) || urlErr.URL == "" {
		return msg
	}
	redacted := RedactURL(urlErr.URL)
	if redacted == urlErr.URL {
		return msg
	}
	// *url.Error formats its URL with %q, so replace the quoted form
	// first (it may contain escapes), then any unquoted copies that a
	// wrapping error may have added.
	msg = strings.ReplaceAll(msg, fmt.Sprintf("%q", urlErr.URL), fmt.Sprintf("%q", redacted))
	return strings.ReplaceAll(msg, urlErr.URL, redacted)
}
|
|
|
|
// MCPToolIdentifier is implemented by tools that originate from an MCP
// server config and can report that config's database ID.
// mcpToolWrapper implements it, and ConnectAll relies on it to
// tiebreak tools whose prefixed names collide.
type MCPToolIdentifier interface {
	// MCPServerConfigID returns the database ID of the MCP server
	// config the tool was discovered on.
	MCPServerConfigID() uuid.UUID
}
|
|
|
|
// mcpToolWrapper adapts a single MCP tool into a fantasy.AgentTool.
// Info() reports the slug-prefixed name, while Run forwards calls to
// the remote server under the original, unprefixed name.
type mcpToolWrapper struct {
	// configID is the database ID of the MCP server config the tool
	// was discovered on; exposed via MCPServerConfigID.
	configID uuid.UUID
	// prefixedName is "<server slug>__<tool name>", the name the
	// model sees.
	prefixedName string
	// originalName is the tool's name on the remote server, used
	// when calling it.
	originalName string
	// description is the tool description reported by the server.
	description string
	// parameters and required are copied from the tool's input
	// schema (InputSchema.Properties / InputSchema.Required).
	parameters map[string]any
	required   []string
	// modelIntent makes Info wrap the schema with a required
	// "model_intent" field, which Run strips before the call.
	modelIntent bool
	// client is the connection the tool was discovered on; connectOne
	// hands the same client to every tool from one server.
	client *client.Client
	// providerOptions holds options set via SetProviderOptions.
	providerOptions fantasy.ProviderOptions
}
|
|
|
|
// MCPServerConfigID returns the database ID of the MCP server config
// that this tool originates from. It satisfies MCPToolIdentifier.
func (t *mcpToolWrapper) MCPServerConfigID() uuid.UUID {
	return t.configID
}
|
|
|
|
// newMCPTool creates an mcpToolWrapper from an mcp.Tool
|
|
// discovered on a remote server.
|
|
func newMCPTool(
|
|
configID uuid.UUID,
|
|
serverSlug string,
|
|
tool mcp.Tool,
|
|
mcpClient *client.Client,
|
|
modelIntent bool,
|
|
) *mcpToolWrapper {
|
|
return &mcpToolWrapper{
|
|
configID: configID,
|
|
prefixedName: serverSlug + toolNameSep + tool.Name,
|
|
originalName: tool.Name,
|
|
description: tool.Description,
|
|
parameters: tool.InputSchema.Properties,
|
|
required: tool.InputSchema.Required,
|
|
modelIntent: modelIntent,
|
|
client: mcpClient,
|
|
}
|
|
}
|
|
|
|
func (t *mcpToolWrapper) Info() fantasy.ToolInfo {
|
|
required := t.required
|
|
if required == nil {
|
|
required = []string{}
|
|
}
|
|
|
|
if !t.modelIntent {
|
|
return fantasy.ToolInfo{
|
|
Name: t.prefixedName,
|
|
Description: t.description,
|
|
Parameters: t.parameters,
|
|
Required: required,
|
|
Parallel: true,
|
|
}
|
|
}
|
|
|
|
// Wrap original parameters under "properties" and add
|
|
// "model_intent" so the LLM provides a human-readable
|
|
// description of each tool call.
|
|
wrapped := map[string]any{
|
|
"model_intent": map[string]any{
|
|
"type": "string",
|
|
"description": "A short, natural-language, present-participle " +
|
|
"phrase describing why you are calling this tool. " +
|
|
"This is shown to the user as a status label while " +
|
|
"the tool runs. Use plain English with no underscores " +
|
|
"or technical jargon. Keep it under 100 characters. " +
|
|
"Good examples: \"Reading the authentication module\", " +
|
|
"\"Searching for configuration files\", " +
|
|
"\"Creating a new workspace\".",
|
|
},
|
|
"properties": map[string]any{
|
|
"type": "object",
|
|
"properties": t.parameters,
|
|
"required": required,
|
|
},
|
|
}
|
|
return fantasy.ToolInfo{
|
|
Name: t.prefixedName,
|
|
Description: t.description,
|
|
Parameters: wrapped,
|
|
Required: []string{"model_intent", "properties"},
|
|
Parallel: true,
|
|
}
|
|
}
|
|
|
|
func (t *mcpToolWrapper) Run(
|
|
ctx context.Context,
|
|
params fantasy.ToolCall,
|
|
) (fantasy.ToolResponse, error) {
|
|
input := params.Input
|
|
if t.modelIntent {
|
|
input = unwrapModelIntent(input)
|
|
}
|
|
|
|
var args map[string]any
|
|
if input != "" {
|
|
if err := json.Unmarshal(
|
|
[]byte(input), &args,
|
|
); err != nil {
|
|
return fantasy.NewTextErrorResponse(
|
|
"invalid JSON input: " + err.Error(),
|
|
), nil
|
|
}
|
|
}
|
|
|
|
callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
|
|
defer cancel()
|
|
|
|
result, err := t.client.CallTool(
|
|
callCtx,
|
|
mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Name: t.originalName,
|
|
Arguments: args,
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return fantasy.NewTextErrorResponse(err.Error()), nil
|
|
}
|
|
|
|
return convertCallResult(result), nil
|
|
}
|
|
|
|
// ProviderOptions returns the provider-specific options last stored
// by SetProviderOptions, or the zero value if none were set.
func (t *mcpToolWrapper) ProviderOptions() fantasy.ProviderOptions {
	return t.providerOptions
}

// SetProviderOptions replaces the tool's provider-specific options.
// The field is written without synchronization.
// NOTE(review): confirm callers never race this with ProviderOptions.
func (t *mcpToolWrapper) SetProviderOptions(
	opts fantasy.ProviderOptions,
) {
	t.providerOptions = opts
}
|
|
|
|
// unwrapModelIntent strips the model_intent wrapper from tool call
// input so the remote MCP server receives only the original
// arguments. It handles the shapes models produce:
//
//  1. { model_intent, properties: {...} }   — the advertised format
//  2. { model_intent, properties: "{...}" } — same, but the nested
//     object double-encoded as a JSON string
//  3. { model_intent, key: val, ... }       — flat, no properties wrapper
//  4. Anything that isn't a JSON object     — returned as-is
//
// In the flat shape, a tool argument that is itself named "properties"
// is indistinguishable from the wrapper and is treated as shape 1.
func unwrapModelIntent(input string) string {
	var parsed map[string]any
	if err := json.Unmarshal([]byte(input), &parsed); err != nil {
		return input
	}

	delete(parsed, "model_intent")

	if props, ok := parsed["properties"]; ok {
		// Shape 2: only accept a string that decodes to a JSON
		// object; any other string keeps the shape-1 handling below.
		if encoded, isString := props.(string); isString {
			var obj map[string]any
			if err := json.Unmarshal([]byte(encoded), &obj); err == nil && obj != nil {
				return encoded
			}
		}
		// Shape 1.
		if b, err := json.Marshal(props); err == nil {
			return string(b)
		}
	}

	// Shape 3.
	if b, err := json.Marshal(parsed); err == nil {
		return string(b)
	}

	return input
}
|
|
|
|
// convertCallResult translates an MCP CallToolResult into a
|
|
// fantasy.ToolResponse. The fantasy response model supports a
|
|
// single content type per response, so we prioritize text. All
|
|
// text items are collected first. Binary items (image, audio,
|
|
// or embedded blob) are only returned when no text content is
|
|
// available.
|
|
func convertCallResult(
|
|
result *mcp.CallToolResult,
|
|
) fantasy.ToolResponse {
|
|
if result == nil {
|
|
return fantasy.NewTextResponse("")
|
|
}
|
|
|
|
var (
|
|
textParts []string
|
|
binaryResult *fantasy.ToolResponse
|
|
)
|
|
for _, item := range result.Content {
|
|
switch c := item.(type) {
|
|
case mcp.TextContent:
|
|
textParts = append(textParts, strings.ToValidUTF8(c.Text, "\uFFFD"))
|
|
case mcp.ImageContent:
|
|
data, err := base64.StdEncoding.DecodeString(
|
|
c.Data,
|
|
)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[image decode error: "+err.Error()+"]",
|
|
)
|
|
continue
|
|
}
|
|
if binaryResult == nil {
|
|
r := fantasy.ToolResponse{
|
|
Type: "image",
|
|
Data: data,
|
|
MediaType: c.MIMEType,
|
|
IsError: result.IsError,
|
|
}
|
|
binaryResult = &r
|
|
}
|
|
case mcp.AudioContent:
|
|
data, err := base64.StdEncoding.DecodeString(
|
|
c.Data,
|
|
)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[audio decode error: "+err.Error()+"]",
|
|
)
|
|
continue
|
|
}
|
|
if binaryResult == nil {
|
|
r := fantasy.ToolResponse{
|
|
Type: "media",
|
|
Data: data,
|
|
MediaType: c.MIMEType,
|
|
IsError: result.IsError,
|
|
}
|
|
binaryResult = &r
|
|
}
|
|
case mcp.EmbeddedResource:
|
|
// Embedded resources wrap either text or blob
|
|
// content from an MCP resource. We handle each
|
|
// variant so the LLM receives the content
|
|
// regardless of form.
|
|
switch r := c.Resource.(type) {
|
|
case mcp.TextResourceContents:
|
|
textParts = append(textParts, strings.ToValidUTF8(r.Text, "\uFFFD"))
|
|
case mcp.BlobResourceContents:
|
|
data, err := base64.StdEncoding.DecodeString(
|
|
r.Blob,
|
|
)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[blob decode error: "+err.Error()+"]",
|
|
)
|
|
continue
|
|
}
|
|
if binaryResult == nil {
|
|
blobType := "media"
|
|
if strings.HasPrefix(r.MIMEType, "image/") {
|
|
blobType = "image"
|
|
}
|
|
res := fantasy.ToolResponse{
|
|
Type: blobType,
|
|
Data: data,
|
|
MediaType: r.MIMEType,
|
|
IsError: result.IsError,
|
|
}
|
|
binaryResult = &res
|
|
}
|
|
default:
|
|
textParts = append(textParts,
|
|
fmt.Sprintf(
|
|
"[unsupported embedded resource type: %T]",
|
|
c.Resource,
|
|
),
|
|
)
|
|
}
|
|
case mcp.ResourceLink:
|
|
// Resource links point to content the LLM can
|
|
// reference by URI. Surface the URI so the model
|
|
// can use it in follow-ups.
|
|
label := c.URI
|
|
if c.Name != "" {
|
|
label = fmt.Sprintf("%s (%s)", c.Name, c.URI)
|
|
}
|
|
if c.Description != "" {
|
|
label += ": " + c.Description
|
|
}
|
|
textParts = append(textParts,
|
|
fmt.Sprintf("[resource: %s]", label),
|
|
)
|
|
default:
|
|
textParts = append(textParts,
|
|
fmt.Sprintf("[unsupported content type: %T]", c),
|
|
)
|
|
}
|
|
}
|
|
|
|
// If structured content is present, marshal it to JSON and
|
|
// append as a text part so the data is preserved for the LLM.
|
|
if result.StructuredContent != nil {
|
|
data, err := json.Marshal(result.StructuredContent)
|
|
if err != nil {
|
|
textParts = append(textParts,
|
|
"[structured content marshal error: "+
|
|
err.Error()+"]",
|
|
)
|
|
} else {
|
|
textParts = append(textParts, string(data))
|
|
}
|
|
}
|
|
|
|
// Prefer text content. Only fall back to binary when no
|
|
// text was collected.
|
|
if len(textParts) > 0 {
|
|
resp := fantasy.NewTextResponse(
|
|
strings.Join(textParts, "\n"),
|
|
)
|
|
resp.IsError = result.IsError
|
|
return resp
|
|
}
|
|
if binaryResult != nil {
|
|
return *binaryResult
|
|
}
|
|
return fantasy.NewTextResponse("")
|
|
}
|
|
|
|
// RefreshResult contains the outcome of an OAuth2 token refresh
// attempt made by RefreshOAuth2Token.
type RefreshResult struct {
	// AccessToken is the new (or unchanged) access token.
	AccessToken string
	// RefreshToken is the new refresh token, or the original one when
	// the provider did not rotate it (returned an empty value).
	RefreshToken string
	// TokenType is the token type reported for the token (usually
	// "Bearer"); it may be empty if the provider omitted it.
	TokenType string
	// Expiry is the token expiry. The zero value means no expiry was
	// provided by the provider.
	Expiry time.Time
	// Refreshed is true when the access token changed, meaning a
	// refresh occurred. When false the token was still valid and no
	// network call was made.
	Refreshed bool
}
|
|
|
|
// RefreshOAuth2Token checks whether the given MCP user token is
|
|
// expired (or within 10 seconds of expiry) and refreshes it using
|
|
// the OAuth2 credentials from the server config. If the token is
|
|
// still valid, no network call is made and Refreshed is false.
|
|
//
|
|
// The caller is responsible for persisting the result when
|
|
// Refreshed is true.
|
|
func RefreshOAuth2Token(
|
|
ctx context.Context,
|
|
cfg database.MCPServerConfig,
|
|
tok database.MCPServerUserToken,
|
|
) (RefreshResult, error) {
|
|
oauth2Cfg := &oauth2.Config{
|
|
ClientID: cfg.OAuth2ClientID,
|
|
ClientSecret: cfg.OAuth2ClientSecret,
|
|
Endpoint: oauth2.Endpoint{
|
|
TokenURL: cfg.OAuth2TokenURL,
|
|
},
|
|
}
|
|
|
|
oldToken := &oauth2.Token{
|
|
AccessToken: tok.AccessToken,
|
|
RefreshToken: tok.RefreshToken,
|
|
TokenType: tok.TokenType,
|
|
}
|
|
if tok.Expiry.Valid {
|
|
oldToken.Expiry = tok.Expiry.Time
|
|
}
|
|
|
|
// Cap the refresh HTTP call so a stalled token endpoint
|
|
// cannot block the entire MCP connection phase. The timeout
|
|
// matches connectTimeout used for MCP server connections.
|
|
refreshCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
|
defer cancel()
|
|
|
|
// TokenSource automatically refreshes expired tokens. It
|
|
// uses a 10-second expiry window, so tokens about to expire
|
|
// are also refreshed proactively.
|
|
newToken, err := oauth2Cfg.TokenSource(refreshCtx, oldToken).Token()
|
|
if err != nil {
|
|
return RefreshResult{}, xerrors.Errorf("refresh oauth2 token: %w", err)
|
|
}
|
|
|
|
refreshed := newToken.AccessToken != tok.AccessToken
|
|
|
|
// Preserve the old refresh token when the provider doesn't
|
|
// rotate (returns empty).
|
|
refreshToken := cmp.Or(newToken.RefreshToken, tok.RefreshToken)
|
|
|
|
return RefreshResult{
|
|
AccessToken: newToken.AccessToken,
|
|
RefreshToken: refreshToken,
|
|
TokenType: newToken.TokenType,
|
|
Expiry: newToken.Expiry,
|
|
Refreshed: refreshed,
|
|
}, nil
|
|
}
|