Files
coder/coderd/x/chatd/mcpclient/mcpclient.go
T
Steven Masley d2f9ad783e feat(coderd/x/chatd): overlay user-set custom_headers at runtime
Threads the per-user custom_headers values stored on
mcp_server_user_header_values through the chatd MCP client so users
who provided a value for an admin-marked CustomHeadersUserKey see it
mixed into the outgoing request alongside the admin-static headers.

Changes:

- mcpclient.ConnectAll grows a fourth indexed input,
  []database.McpServerUserHeaderValue, which buildAuthHeaders
  consults inside the custom_headers branch to overlay per-user
  values on top of admin static headers, scoped to
  cfg.CustomHeadersUserKeys.
- chatd loads the user's stored header values via
  GetMCPServerUserHeaderValuesByUserID alongside the existing
  GetMCPServerUserTokensByUserID call and threads them into
  ConnectAll. A missing row is non-fatal: admin headers still
  ship, user-keyed headers are simply absent and a warning is
  logged.
- mcpclient.go inlines its own DefaultTransport clone for test
  isolation, replacing the standalone helper in mcphttpclient.go,
  which is removed.

Stack: 4/6 (chatd runtime overlay)
2026-06-01 15:02:34 +00:00

955 lines
27 KiB
Go

package mcpclient
import (
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/database"
)
const (
	// toolNameSep joins a server slug and a tool's original name to
	// form the prefixed name exposed to the model. A double
	// underscore is used so tool names that contain a single
	// underscore never collide with the separator.
	//
	// TODO: tool names that themselves contain "__" yield ambiguous
	// prefixed names: "srv__my__tool" could be slug "srv" with tool
	// "my__tool" or slug "srv__my" with tool "tool". Invocation is
	// unaffected because calls to the remote server use the original
	// name directly.
	toolNameSep = "__"

	// connectTimeout is the budget for bringing up a single MCP
	// server: starting its transport, initializing, and listing its
	// tools. Slower servers are skipped so one laggard cannot stall
	// chat startup. It also bounds OAuth2 token refresh requests.
	connectTimeout = 10 * time.Second

	// toolCallTimeout caps how long one tool invocation may run
	// before it is canceled.
	toolCallTimeout = 60 * time.Second
)
// UserOIDCTokenSource supplies the calling user's OIDC access token
// for MCP servers configured with the "user_oidc" auth type.
//
// Contract for implementations:
//   - Tokens that are expired or close to expiring should be
//     refreshed before being returned.
//   - Return ("", nil) when the user has no OIDC link or when a
//     refresh attempt failed for any reason; the request then goes
//     out without an Authorization header.
//   - Reserve a non-nil error for unexpected infrastructure failures
//     (e.g. database errors); header construction for the server is
//     then skipped entirely.
//
// The empty-token-on-refresh-failure rule matches
// provisionerdserver.ObtainOIDCAccessToken.
type UserOIDCTokenSource interface {
	OIDCAccessToken(ctx context.Context, userID uuid.UUID) (string, error)
}
// ConnectAll dials every enabled MCP server in configs concurrently,
// discovers the tools each one exposes, and returns them wrapped as
// fantasy.AgentTool values together with a cleanup function.
//
// A server that fails to connect is logged at warn level (with its
// URL and error redacted) and skipped; one failure never aborts the
// others. The returned tools are ordered by prefixed name, with ties
// broken by config ID, so prompt construction is stable no matter
// which server answered first. The caller must invoke the returned
// cleanup function to close every client that connected.
func ConnectAll(
	ctx context.Context,
	logger slog.Logger,
	configs []database.MCPServerConfig,
	tokens []database.MCPServerUserToken,
	userHeaderValues []database.McpServerUserHeaderValue,
	userID uuid.UUID,
	oidcSrc UserOIDCTokenSource,
	coderHeaders map[string]string,
) ([]fantasy.AgentTool, func()) {
	// Key the per-user auth inputs by server config ID so each
	// connection finds its own entry in O(1).
	tokenIndex := make(map[uuid.UUID]database.MCPServerUserToken, len(tokens))
	for _, token := range tokens {
		tokenIndex[token.MCPServerConfigID] = token
	}
	headerValueIndex := make(map[uuid.UUID]database.McpServerUserHeaderValue, len(userHeaderValues))
	for _, row := range userHeaderValues {
		headerValueIndex[row.MCPServerConfigID] = row
	}

	var (
		lock      sync.Mutex
		connected []*client.Client
		collected []fantasy.AgentTool
		group     errgroup.Group
	)

	for _, cfg := range configs {
		if !cfg.Enabled {
			continue
		}
		group.Go(func() error {
			serverTools, mcpClient, err := connectOne(
				ctx, logger, cfg, tokenIndex, headerValueIndex, userID, oidcSrc, coderHeaders,
			)
			if err != nil {
				// A broken server only costs the model that server's
				// tools, so the failure is logged and not propagated.
				logger.Warn(ctx,
					"skipping MCP server due to connection failure",
					slog.F("server_slug", cfg.Slug),
					slog.F("server_url", RedactURL(cfg.Url)),
					slog.F("error", redactErrorURL(err)),
				)
				return nil
			}
			lock.Lock()
			defer lock.Unlock()
			if mcpClient != nil {
				connected = append(connected, mcpClient)
			}
			collected = append(collected, serverTools...)
			return nil
		})
	}
	// Every goroutine returns nil, so Wait's error carries no
	// information and is dropped on purpose.
	_ = group.Wait()

	// Goroutines finish in arbitrary order. Sort by prefixed name and
	// fall back to the config ID when the "__" separator makes two
	// prefixed names collide; stable prompts depend on this order.
	slices.SortFunc(collected, func(a, b fantasy.AgentTool) int {
		// connectOne only produces *mcpToolWrapper values, so these
		// assertions hold; a failure indicates a programming error.
		aID, aOK := a.(MCPToolIdentifier)
		if !aOK {
			panic(fmt.Sprintf("unexpected tool type %T", a))
		}
		bID, bOK := b.(MCPToolIdentifier)
		if !bOK {
			panic(fmt.Sprintf("unexpected tool type %T", b))
		}
		if byName := cmp.Compare(a.Info().Name, b.Info().Name); byName != 0 {
			return byName
		}
		return cmp.Compare(aID.MCPServerConfigID().String(), bID.MCPServerConfigID().String())
	})

	cleanup := func() {
		lock.Lock()
		defer lock.Unlock()
		for _, mc := range connected {
			_ = mc.Close()
		}
		connected = nil
	}
	return collected, cleanup
}
// connectOne brings up a single MCP server and returns its permitted
// tools, each wrapped with the server slug prefix, along with the
// live client that serves them.
//
// The sequence is: build auth headers, optionally merge Coder
// identity headers, create the transport, then start, initialize and
// list tools under a single connectTimeout deadline. The client is
// closed on every failure path. If no tool survives the allow/deny
// filtering, the client is closed and (nil, nil, nil) is returned.
func connectOne(
	ctx context.Context,
	logger slog.Logger,
	cfg database.MCPServerConfig,
	tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
	userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
	userID uuid.UUID,
	oidcSrc UserOIDCTokenSource,
	coderHeaders map[string]string,
) ([]fantasy.AgentTool, *client.Client, error) {
	headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc)
	if cfg.ForwardCoderHeaders {
		// Coder identity headers must never override an auth header.
		// Names are compared in canonical form because the upstream
		// transport applies http.Header.Set, which canonicalizes; a
		// case-only mismatch would otherwise leave two map entries for
		// one header and make the surviving value non-deterministic.
		taken := make(map[string]struct{}, len(headers))
		for name := range headers {
			taken[http.CanonicalHeaderKey(name)] = struct{}{}
		}
		for name, value := range coderHeaders {
			if _, clash := taken[http.CanonicalHeaderKey(name)]; !clash {
				headers[name] = value
			}
		}
	}

	tr, err := createTransport(cfg, headers)
	if err != nil {
		return nil, nil, xerrors.Errorf("create transport: %w", err)
	}
	mcpClient := client.NewClient(tr)

	// One deadline spans start, initialize and tool listing together,
	// not each phase individually.
	connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
	defer cancel()

	if err := mcpClient.Start(connectCtx); err != nil {
		_ = mcpClient.Close()
		return nil, nil, xerrors.Errorf("start transport: %w", err)
	}

	initReq := mcp.InitializeRequest{
		Params: mcp.InitializeParams{
			ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
			ClientInfo: mcp.Implementation{
				Name:    "coder",
				Version: buildinfo.Version(),
			},
		},
	}
	if _, err := mcpClient.Initialize(connectCtx, initReq); err != nil {
		// Best-effort close so the transport is not leaked.
		_ = mcpClient.Close()
		return nil, nil, xerrors.Errorf("initialize: %w", err)
	}

	listed, err := mcpClient.ListTools(connectCtx, mcp.ListToolsRequest{})
	if err != nil {
		_ = mcpClient.Close()
		return nil, nil, xerrors.Errorf("list tools: %w", err)
	}

	var tools []fantasy.AgentTool
	for _, remote := range listed.Tools {
		if isToolAllowed(remote.Name, cfg.ToolAllowList, cfg.ToolDenyList) {
			tools = append(tools, newMCPTool(cfg.ID, cfg.Slug, remote, mcpClient, cfg.ModelIntent))
			continue
		}
		logger.Debug(ctx, "skipping denied MCP tool",
			slog.F("server_slug", cfg.Slug),
			slog.F("tool_name", remote.Name),
		)
	}

	if len(tools) == 0 {
		// Nothing survived filtering; don't hold an idle connection.
		_ = mcpClient.Close()
		return nil, nil, nil
	}
	return tools, mcpClient, nil
}
// createTransport builds the appropriate mcp-go transport based
// on the server's configured transport type.
func createTransport(
cfg database.MCPServerConfig,
headers map[string]string,
) (transport.Interface, error) {
// Each connection gets its own HTTP client with a dedicated
// transport so that httptest.Server.Close() (which calls
// CloseIdleConnections on http.DefaultTransport) does not
// disrupt unrelated connections during parallel tests.
var httpClient *http.Client
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
httpClient = &http.Client{Transport: dt.Clone()}
} else {
httpClient = &http.Client{}
}
switch cfg.Transport {
case "sse":
return transport.NewSSE(
cfg.Url,
transport.WithHeaders(headers),
transport.WithHTTPClient(httpClient),
)
case "", "streamable_http":
// Default to streamable HTTP, the newer transport.
return transport.NewStreamableHTTP(
cfg.Url,
transport.WithHTTPHeaders(headers),
transport.WithHTTPBasicClient(httpClient),
)
default:
return nil, xerrors.Errorf(
"unsupported transport %q", cfg.Transport,
)
}
}
// buildAuthHeaders returns the HTTP headers used to authenticate to
// the MCP server described by cfg, chosen by cfg.AuthType:
//
//   - "oauth2": Authorization built from the user's stored token in
//     tokensByConfigID, with a "bearer" scheme normalized to "Bearer".
//   - "api_key": the admin-configured header/value pair, when both
//     are set.
//   - "custom_headers": the admin's static JSON header map, overlaid
//     with the user's saved values for cfg.CustomHeadersUserKeys.
//   - "user_oidc": Authorization: Bearer with the token from oidcSrc.
//   - "none", "" or any other value: no headers.
//
// Missing or unusable credentials are logged and omitted rather than
// treated as errors, so the remote server decides how to respond.
// The returned map is always non-nil.
func buildAuthHeaders(
	ctx context.Context,
	logger slog.Logger,
	cfg database.MCPServerConfig,
	tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
	userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
	userID uuid.UUID,
	oidcSrc UserOIDCTokenSource,
) map[string]string {
	// Using map[string]string rather than http.Header because
	// the mcp-go transport options accept map[string]string.
	// MCP servers typically don't require multi-valued headers.
	headers := make(map[string]string)
	switch cfg.AuthType {
	case "oauth2":
		tok, ok := tokensByConfigID[cfg.ID]
		if !ok {
			logger.Warn(ctx,
				"no oauth2 token found for MCP server",
				slog.F("server_slug", cfg.Slug),
			)
			break
		}
		if tok.Expiry.Valid && tok.Expiry.Time.Before(time.Now()) {
			// Only logged: the token is still sent and the remote
			// server decides whether to reject it.
			logger.Warn(ctx,
				"oauth2 token for MCP server is expired",
				slog.F("server_slug", cfg.Slug),
				slog.F("expired_at", tok.Expiry.Time),
			)
		}
		if tok.AccessToken == "" {
			logger.Warn(ctx,
				"oauth2 token record has empty access token",
				slog.F("server_slug", cfg.Slug),
			)
			break
		}
		tokenType := tok.TokenType
		if tokenType == "" {
			tokenType = "Bearer"
		}
		// RFC 6750 says the scheme is case-insensitive, but
		// some servers (e.g. Linear) reject lowercase
		// "bearer". Normalize to the canonical form.
		if strings.EqualFold(tokenType, "bearer") {
			tokenType = "Bearer"
		}
		headers["Authorization"] = tokenType + " " + tok.AccessToken
	case "api_key":
		if cfg.APIKeyHeader != "" && cfg.APIKeyValue != "" {
			headers[cfg.APIKeyHeader] = cfg.APIKeyValue
		}
	case "custom_headers":
		if cfg.CustomHeaders != "" {
			var custom map[string]string
			if err := json.Unmarshal(
				[]byte(cfg.CustomHeaders), &custom,
			); err != nil {
				logger.Warn(ctx,
					"failed to parse custom headers JSON",
					slog.F("server_slug", cfg.Slug),
					slog.Error(err),
				)
			} else {
				for k, v := range custom {
					headers[k] = v
				}
			}
		}
		// Overlay user-supplied values for keys the admin marked as
		// user-set. Validation is expected to keep these keys
		// disjoint from CustomHeaders; if one does overlap, including
		// by case only, the user value replaces the admin entry
		// rather than coexisting with it.
		if len(cfg.CustomHeadersUserKeys) > 0 {
			row, ok := userHeaderValuesByConfigID[cfg.ID]
			if !ok {
				// Normal state: this user has never saved values for
				// this server. The MCP call will proceed without the
				// user-set headers and likely fail at the remote end,
				// which is the expected signal for the UI to prompt
				// the user. Debug-level keeps this off the noise floor.
				logger.Debug(ctx,
					"no user header values for MCP server",
					slog.F("server_slug", cfg.Slug),
				)
				break
			}
			var user map[string]string
			if err := json.Unmarshal(
				[]byte(row.HeaderValues), &user,
			); err != nil {
				logger.Warn(ctx,
					"failed to parse user header values JSON",
					slog.F("server_slug", cfg.Slug),
					slog.Error(err),
				)
				break
			}
			for _, k := range cfg.CustomHeadersUserKeys {
				// Case-insensitive lookup so a case-only admin rename
				// does not silently drop the user's stored value.
				v, has := mcpHeaderValueForKey(user, k)
				if !has || v == "" {
					continue
				}
				// Remove any entry for the same header under a
				// different casing. The transport canonicalizes names
				// with http.Header.Set, so leaving both in the map
				// would make the surviving value depend on map
				// iteration order.
				canonical := http.CanonicalHeaderKey(k)
				for existing := range headers {
					if existing != k && http.CanonicalHeaderKey(existing) == canonical {
						delete(headers, existing)
					}
				}
				headers[k] = v
			}
		}
	case "user_oidc":
		// Forward the calling user's OIDC access token from
		// user_links as Authorization: Bearer <token>. The token
		// source is responsible for refreshing tokens that are
		// expired or close to expiring before returning them.
		if oidcSrc == nil || userID == uuid.Nil {
			logger.Warn(ctx,
				"user_oidc auth requested but no token source available",
				slog.F("server_slug", cfg.Slug),
			)
			break
		}
		token, err := oidcSrc.OIDCAccessToken(ctx, userID)
		if err != nil {
			logger.Warn(ctx,
				"failed to obtain user OIDC token for MCP server",
				slog.F("server_slug", cfg.Slug),
				slog.Error(err),
			)
			break
		}
		if token == "" {
			// The user has no OIDC link, or a non-fatal refresh
			// failure occurred. Fall through with no header and let
			// the upstream MCP server decide how to respond
			// (typically 401). Logged at debug so password and
			// GitHub users don't generate noise for every chat turn.
			logger.Debug(ctx,
				"no user OIDC token available for MCP server",
				slog.F("server_slug", cfg.Slug),
			)
			break
		}
		headers["Authorization"] = "Bearer " + token
	case "none", "":
		// No auth headers needed.
	}
	return headers
}
// mcpHeaderValueForKey looks up key in stored, preferring an exact
// match and otherwise accepting any stored key equal to it under
// Unicode case folding. It reports whether a value was found.
//
// The stored user-header map keeps the admin's casing from when the
// value was saved, so without the fallback a later case-only rename
// of a user-set key would orphan the stored value until the user
// saved it again.
func mcpHeaderValueForKey(stored map[string]string, key string) (string, bool) {
	// Fast path: the key as currently written by the admin.
	v, ok := stored[key]
	if ok {
		return v, true
	}
	for storedKey, storedVal := range stored {
		if !strings.EqualFold(storedKey, key) {
			continue
		}
		return storedVal, true
	}
	return "", false
}
// isToolAllowed reports whether toolName passes the allow/deny lists.
// A non-empty allow list is authoritative: only tools named in it are
// permitted and the deny list is not consulted. Otherwise every tool
// is permitted except those named in the deny list. Matching is exact
// and uses the original, unprefixed tool name.
func isToolAllowed(
	toolName string,
	allowList []string,
	denyList []string,
) bool {
	if len(allowList) != 0 {
		return slices.Contains(allowList, toolName)
	}
	return !slices.Contains(denyList, toolName)
}
// RedactURL returns rawURL with userinfo, query string and fragment
// removed so it can be logged without leaking credentials. Query
// parameters are dropped because API keys are sometimes passed as
// ?api_key=sk-... in server URLs.
//
// When url.Parse rejects rawURL, the input is still redacted on a
// best-effort basis rather than returned verbatim. An unparseable URL
// (e.g. one with an invalid port or a bad %-escape) can still carry
// "user:password@" or a secret query string. The fallback errs toward
// removing too much: everything between "://" and the last '@' is
// treated as userinfo, and everything from the first '?' or '#' on is
// dropped.
func RedactURL(rawURL string) string {
	u, err := url.Parse(rawURL)
	if err == nil {
		u.User = nil
		u.RawQuery = ""
		u.Fragment = ""
		return u.String()
	}

	redacted := rawURL
	if at := strings.LastIndex(redacted, "@"); at >= 0 {
		prefix := ""
		if scheme, _, ok := strings.Cut(redacted[:at], "://"); ok {
			prefix = scheme + "://"
		}
		redacted = prefix + redacted[at+1:]
	}
	redacted, _, _ = strings.Cut(redacted, "?")
	redacted, _, _ = strings.Cut(redacted, "#")
	return redacted
}

// redactErrorURL returns err's message with the URL of any wrapped
// *url.Error passed through RedactURL. Go's net/http embeds the full
// request URL in *url.Error messages, which can leak userinfo or
// query-string secrets.
//
// The surrounding wrap context (e.g. "start transport: ...") is kept,
// and err itself is not modified. The redaction is applied to a copy
// of the *url.Error and spliced into the full message. Any other
// verbatim copy of the raw URL in the message is scrubbed as well.
func redactErrorURL(err error) string {
	if err == nil {
		return ""
	}
	msg := err.Error()
	var urlErr *url.Error
	if !errors.As(err, &urlErr) {
		return msg
	}

	safeURL := RedactURL(urlErr.URL)
	scrubbed := (&url.Error{Op: urlErr.Op, URL: safeURL, Err: urlErr.Err}).Error()
	if inner := urlErr.Error(); strings.Contains(msg, inner) {
		msg = strings.ReplaceAll(msg, inner, scrubbed)
	} else {
		// The wrapper does not embed the inner message verbatim, so
		// it cannot be patched safely; report the scrubbed
		// *url.Error message alone.
		msg = scrubbed
	}
	// Guard against an empty URL: ReplaceAll with an empty old string
	// would insert safeURL between every character.
	if urlErr.URL != "" && urlErr.URL != safeURL {
		msg = strings.ReplaceAll(msg, urlErr.URL, safeURL)
	}
	return msg
}
// MCPToolIdentifier is satisfied by tools that were discovered on an
// MCP server config and can report that config's database ID.
// ConnectAll uses it to break ordering ties between tools whose
// prefixed names are identical.
type MCPToolIdentifier interface {
	MCPServerConfigID() uuid.UUID
}
// mcpToolWrapper adapts a single MCP tool into a
// fantasy.AgentTool. It stores the prefixed name for Info() but
// strips the prefix when forwarding calls to the remote server.
type mcpToolWrapper struct {
	// configID is the database ID of the MCP server config the tool
	// was discovered on; reported by MCPServerConfigID.
	configID uuid.UUID
	// prefixedName is serverSlug + toolNameSep + originalName, the
	// name the model sees via Info.
	prefixedName string
	// originalName is the tool name as advertised by the server; it
	// is the name sent in CallTool requests.
	originalName string
	// description is the server-provided tool description.
	description string
	// parameters and required are copied from the tool's input
	// schema (InputSchema.Properties and InputSchema.Required).
	parameters map[string]any
	required   []string
	// modelIntent, when true, makes Info wrap the arguments with a
	// model_intent field and Run strip that wrapper before calling
	// the server.
	modelIntent bool
	// client is the connected MCP client used to invoke the tool.
	client *client.Client
	// providerOptions holds whatever was last passed to
	// SetProviderOptions (zero value until then).
	providerOptions fantasy.ProviderOptions
}
// MCPServerConfigID reports the database ID of the MCP server config
// this tool was discovered on, satisfying MCPToolIdentifier.
func (t *mcpToolWrapper) MCPServerConfigID() uuid.UUID {
	return t.configID
}

// newMCPTool wraps a tool advertised by a remote MCP server. The
// model-facing name is serverSlug + toolNameSep + tool.Name, while the
// bare tool.Name is kept for calls back to the server. The input
// schema's properties and required list are carried over as-is.
func newMCPTool(
	configID uuid.UUID,
	serverSlug string,
	tool mcp.Tool,
	mcpClient *client.Client,
	modelIntent bool,
) *mcpToolWrapper {
	schema := tool.InputSchema
	w := &mcpToolWrapper{
		configID:     configID,
		originalName: tool.Name,
		description:  tool.Description,
		parameters:   schema.Properties,
		required:     schema.Required,
		modelIntent:  modelIntent,
		client:       mcpClient,
	}
	w.prefixedName = serverSlug + toolNameSep + w.originalName
	return w
}
// Info describes the tool to the model under its prefixed name.
//
// When modelIntent is false, the server's input schema properties and
// required list are passed through. When it is true, the original
// arguments are nested under a "properties" object and a required
// "model_intent" string is added, which the model fills with a short
// status label; Run strips that wrapper again before calling the
// server.
//
// Nil slices and maps placed into the schema built here are replaced
// with empty ones so they encode as [] and {} rather than null, which
// JSON Schema does not accept for "required" or "properties".
func (t *mcpToolWrapper) Info() fantasy.ToolInfo {
	required := t.required
	if required == nil {
		required = []string{}
	}
	if !t.modelIntent {
		return fantasy.ToolInfo{
			Name:        t.prefixedName,
			Description: t.description,
			Parameters:  t.parameters,
			Required:    required,
			Parallel:    true,
		}
	}
	// A tool with no declared inputs can have nil Properties; inside
	// the nested object schema that would encode as
	// "properties": null.
	properties := t.parameters
	if properties == nil {
		properties = map[string]any{}
	}
	// Wrap original parameters under "properties" and add
	// "model_intent" so the LLM provides a human-readable
	// description of each tool call.
	wrapped := map[string]any{
		"model_intent": map[string]any{
			"type": "string",
			"description": "A short, natural-language, present-participle " +
				"phrase describing why you are calling this tool. " +
				"This is shown to the user as a status label while " +
				"the tool runs. Use plain English with no underscores " +
				"or technical jargon. Keep it under 100 characters. " +
				"Good examples: \"Reading the authentication module\", " +
				"\"Searching for configuration files\", " +
				"\"Creating a new workspace\".",
		},
		"properties": map[string]any{
			"type":       "object",
			"properties": properties,
			"required":   required,
		},
	}
	return fantasy.ToolInfo{
		Name:        t.prefixedName,
		Description: t.description,
		Parameters:  wrapped,
		Required:    []string{"model_intent", "properties"},
		Parallel:    true,
	}
}
// Run executes the tool on the remote MCP server.
//
// params.Input holds the model-produced JSON arguments; when
// modelIntent is enabled the model_intent wrapper is stripped first.
// Invalid JSON and CallTool failures are returned to the model as text
// error responses, not Go errors, so the returned error is always nil.
// Each call is bounded by toolCallTimeout.
func (t *mcpToolWrapper) Run(
	ctx context.Context,
	params fantasy.ToolCall,
) (fantasy.ToolResponse, error) {
	input := params.Input
	if t.modelIntent {
		input = unwrapModelIntent(input)
	}
	var args map[string]any
	if input != "" {
		if err := json.Unmarshal(
			[]byte(input), &args,
		); err != nil {
			return fantasy.NewTextErrorResponse(
				"invalid JSON input: " + err.Error(),
			), nil
		}
	}
	callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
	defer cancel()
	result, err := t.client.CallTool(
		callCtx,
		mcp.CallToolRequest{
			Params: mcp.CallToolParams{
				Name:      t.originalName,
				Arguments: args,
			},
		},
	)
	if err != nil {
		// Transport errors from net/http embed the request URL, which
		// may carry userinfo or an ?api_key=... query. This text is
		// handed to the model, so redact it before returning.
		return fantasy.NewTextErrorResponse(redactErrorURL(err)), nil
	}
	return convertCallResult(result), nil
}
// ProviderOptions returns the provider-specific options last stored
// with SetProviderOptions, or the zero value if none were set.
func (t *mcpToolWrapper) ProviderOptions() fantasy.ProviderOptions {
	return t.providerOptions
}

// SetProviderOptions stores provider-specific options for this tool,
// replacing any previous value.
func (t *mcpToolWrapper) SetProviderOptions(opts fantasy.ProviderOptions) {
	t.providerOptions = opts
}
// unwrapModelIntent removes the model_intent wrapper from a tool
// call's JSON input so the remote MCP server receives only its own
// arguments. Three input shapes are handled:
//
//  1. {"model_intent": ..., "properties": {...}}: the requested shape;
//     the "properties" value is returned.
//  2. {"model_intent": ..., "key": val, ...}: the model skipped the
//     wrapper; the object without model_intent is returned.
//  3. Input that does not decode as a JSON object: returned unchanged.
//
// Re-encoding goes through encoding/json, so object keys come back in
// sorted order.
func unwrapModelIntent(input string) string {
	var obj map[string]any
	if json.Unmarshal([]byte(input), &obj) != nil {
		return input
	}
	delete(obj, "model_intent")

	// Try the nested "properties" value first (shape 1), then the
	// flat object itself (shape 2).
	candidates := []any{obj}
	if inner, found := obj["properties"]; found {
		candidates = []any{inner, obj}
	}
	for _, candidate := range candidates {
		if encoded, err := json.Marshal(candidate); err == nil {
			return string(encoded)
		}
	}
	return input
}
// convertCallResult translates an MCP CallToolResult into a
// fantasy.ToolResponse.
//
// A fantasy response carries a single content type, so text wins. The
// following are all joined with newlines into one text response:
// text items, text resources, resource links, decode or
// unsupported-type notices, and marshaled structured content.
//
// Binary items (image, audio or blob resources) are used only when no
// text was produced, and then only the first one. result.IsError is
// carried onto whichever response is returned, including the empty
// one, so a server-reported failure with no content is not mistaken
// for success.
func convertCallResult(
	result *mcp.CallToolResult,
) fantasy.ToolResponse {
	if result == nil {
		return fantasy.NewTextResponse("")
	}

	var (
		textParts    []string
		binaryResult *fantasy.ToolResponse
	)
	// keepBinary records the first decoded binary item; later ones
	// are dropped because a response can hold only one payload.
	keepBinary := func(kind, mediaType string, data []byte) {
		if binaryResult != nil {
			return
		}
		binaryResult = &fantasy.ToolResponse{
			Type:      kind,
			Data:      data,
			MediaType: mediaType,
			IsError:   result.IsError,
		}
	}

	for _, item := range result.Content {
		switch c := item.(type) {
		case mcp.TextContent:
			textParts = append(textParts, strings.ToValidUTF8(c.Text, "\uFFFD"))
		case mcp.ImageContent:
			data, err := base64.StdEncoding.DecodeString(c.Data)
			if err != nil {
				textParts = append(textParts, "[image decode error: "+err.Error()+"]")
				continue
			}
			keepBinary("image", c.MIMEType, data)
		case mcp.AudioContent:
			data, err := base64.StdEncoding.DecodeString(c.Data)
			if err != nil {
				textParts = append(textParts, "[audio decode error: "+err.Error()+"]")
				continue
			}
			keepBinary("media", c.MIMEType, data)
		case mcp.EmbeddedResource:
			// Embedded resources wrap either text or blob content from
			// an MCP resource; both variants are surfaced so the LLM
			// receives the content regardless of form.
			switch r := c.Resource.(type) {
			case mcp.TextResourceContents:
				textParts = append(textParts, strings.ToValidUTF8(r.Text, "\uFFFD"))
			case mcp.BlobResourceContents:
				data, err := base64.StdEncoding.DecodeString(r.Blob)
				if err != nil {
					textParts = append(textParts, "[blob decode error: "+err.Error()+"]")
					continue
				}
				kind := "media"
				if strings.HasPrefix(r.MIMEType, "image/") {
					kind = "image"
				}
				keepBinary(kind, r.MIMEType, data)
			default:
				textParts = append(textParts,
					fmt.Sprintf(
						"[unsupported embedded resource type: %T]",
						c.Resource,
					),
				)
			}
		case mcp.ResourceLink:
			// Resource links point to content the LLM can reference
			// by URI. Surface the URI so the model can use it in
			// follow-ups.
			label := c.URI
			if c.Name != "" {
				label = fmt.Sprintf("%s (%s)", c.Name, c.URI)
			}
			if c.Description != "" {
				label += ": " + c.Description
			}
			textParts = append(textParts, fmt.Sprintf("[resource: %s]", label))
		default:
			textParts = append(textParts, fmt.Sprintf("[unsupported content type: %T]", c))
		}
	}

	// Structured content is marshaled to JSON and appended as text so
	// the data is preserved for the LLM.
	if result.StructuredContent != nil {
		data, err := json.Marshal(result.StructuredContent)
		if err != nil {
			textParts = append(textParts,
				"[structured content marshal error: "+
					err.Error()+"]",
			)
		} else {
			textParts = append(textParts, string(data))
		}
	}

	if len(textParts) > 0 {
		resp := fantasy.NewTextResponse(strings.Join(textParts, "\n"))
		resp.IsError = result.IsError
		return resp
	}
	if binaryResult != nil {
		return *binaryResult
	}
	// No usable content at all. Still propagate IsError so an error
	// result without content is not reported as a success.
	resp := fantasy.NewTextResponse("")
	resp.IsError = result.IsError
	return resp
}
// RefreshResult contains the outcome of an OAuth2 token refresh
// attempt.
type RefreshResult struct {
	// AccessToken is the new (or unchanged) access token.
	AccessToken string
	// RefreshToken is the new refresh token, or the original one
	// when the provider did not return a replacement (providers
	// that don't rotate refresh tokens leave it empty).
	RefreshToken string
	// TokenType is the token type (usually "Bearer").
	TokenType string
	// Expiry is the token expiry. The zero value means no expiry is
	// known.
	Expiry time.Time
	// Refreshed reports whether the returned credentials differ from
	// the stored ones in access token, refresh token or expiry, and
	// therefore need to be persisted. It is false when the stored
	// token was still valid (no network call was made) or when a
	// refresh returned identical credentials.
	Refreshed bool
}

// RefreshOAuth2Token refreshes tok when it is expired or within the
// oauth2 package's expiry window (about 10 seconds) of expiring. It
// uses the OAuth2 client credentials and token URL from cfg. A
// still-valid token is returned as-is, without a network call, and
// with Refreshed false.
//
// The token endpoint request is bounded by connectTimeout. The caller
// is responsible for persisting the result when Refreshed is true.
func RefreshOAuth2Token(
	ctx context.Context,
	cfg database.MCPServerConfig,
	tok database.MCPServerUserToken,
) (RefreshResult, error) {
	oauth2Cfg := &oauth2.Config{
		ClientID:     cfg.OAuth2ClientID,
		ClientSecret: cfg.OAuth2ClientSecret,
		Endpoint: oauth2.Endpoint{
			TokenURL: cfg.OAuth2TokenURL,
		},
	}
	oldToken := &oauth2.Token{
		AccessToken:  tok.AccessToken,
		RefreshToken: tok.RefreshToken,
		TokenType:    tok.TokenType,
	}
	if tok.Expiry.Valid {
		oldToken.Expiry = tok.Expiry.Time
	}
	// Cap the refresh HTTP call so a stalled token endpoint
	// cannot block the entire MCP connection phase. The timeout
	// matches connectTimeout used for MCP server connections.
	refreshCtx, cancel := context.WithTimeout(ctx, connectTimeout)
	defer cancel()
	// TokenSource automatically refreshes expired tokens. It
	// uses a 10-second expiry window, so tokens about to expire
	// are also refreshed proactively.
	newToken, err := oauth2Cfg.TokenSource(refreshCtx, oldToken).Token()
	if err != nil {
		return RefreshResult{}, xerrors.Errorf("refresh oauth2 token: %w", err)
	}
	// Preserve the old refresh token when the provider doesn't
	// rotate (returns empty).
	refreshToken := cmp.Or(newToken.RefreshToken, tok.RefreshToken)
	// Any change, not just a new access token, must be persisted. A
	// provider may rotate the refresh token (RFC 6749 section 6)
	// while returning the same access token; dropping the rotated
	// value would strand the user once the old one is revoked.
	refreshed := newToken.AccessToken != tok.AccessToken ||
		refreshToken != tok.RefreshToken ||
		!newToken.Expiry.Equal(oldToken.Expiry)
	return RefreshResult{
		AccessToken:  newToken.AccessToken,
		RefreshToken: refreshToken,
		TokenType:    newToken.TokenType,
		Expiry:       newToken.Expiry,
		Refreshed:    refreshed,
	}, nil
}