feat(coderd/x/chatd): overlay user-set custom_headers at runtime

Threads the per-user custom_headers values stored on
mcp_server_user_header_values through the chatd MCP client so users
who provided a value for an admin-marked CustomHeadersUserKey see it
mixed into the outgoing request alongside the admin-static headers.

Changes:

- mcpclient.ConnectAll grows a fourth indexed input,
  []database.McpServerUserHeaderValue, which buildAuthHeaders
  consults inside the custom_headers branch to overlay per-user
  values on top of admin static headers, scoped to
  cfg.CustomHeadersUserKeys.
- chatd loads the user's stored header values via
  GetMCPServerUserHeaderValuesByUserID alongside the existing
  GetMCPServerUserTokensByUserID call and threads them into
  ConnectAll. A missing row is non-fatal: admin headers still
  ship, user-keyed headers are simply absent and a warning is
  logged.
- mcpclient.go inlines its own DefaultTransport clone for test
  isolation, replacing the standalone helper in mcphttpclient.go,
  which is removed.

Stack: 4/6 (chatd runtime overlay)
This commit is contained in:
Steven Masley
2026-06-01 15:02:34 +00:00
parent 94939e2fbb
commit d2f9ad783e
5 changed files with 277 additions and 73 deletions
+20 -3
View File
@@ -7130,8 +7130,9 @@ func (p *Server) runChat(
// resolution. These queries have no dependencies on each other and all
// hit different tables.
var (
mcpConfigs []database.MCPServerConfig
mcpTokens []database.MCPServerUserToken
mcpConfigs []database.MCPServerConfig
mcpTokens []database.MCPServerUserToken
mcpHeaderValues []database.McpServerUserHeaderValue
)
var g errgroup.Group
g.Go(func() error {
@@ -7179,6 +7180,22 @@ func (p *Server) runChat(
}
return nil
})
g.Go(func() error {
var err error
// If header-values loading fails, ConnectAll proceeds
// without user values; custom_headers servers that
// require user-set keys will be missing those headers.
mcpHeaderValues, err = p.db.GetMCPServerUserHeaderValuesByUserID(
ctx, chat.OwnerID,
)
if err != nil {
logger.Warn(ctx,
"failed to load MCP user header values",
slog.Error(err),
)
}
return nil
})
}
if err := g.Wait(); err != nil {
return result, err
@@ -7493,7 +7510,7 @@ func (p *Server) runChat(
// Refresh expired OAuth2 tokens before connecting.
mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens)
mcpTools, mcpCleanup = mcpclient.ConnectAll(
ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource,
ctx, logger, mcpConnectConfigs, mcpTokens, mcpHeaderValues, chat.OwnerID, p.oidcTokenSource,
chatprovider.CoderHeaders(chat),
)
return nil
@@ -65,7 +65,7 @@ func TestConnectAll_ForwardCoderHeaders_DefaultOff(t *testing.T) {
}
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
coderHeaders,
)
t.Cleanup(cleanup)
@@ -115,7 +115,7 @@ func TestConnectAll_ForwardCoderHeaders_Enabled(t *testing.T) {
})
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
coderHeaders,
)
t.Cleanup(cleanup)
@@ -158,7 +158,7 @@ func TestConnectAll_ForwardCoderHeaders_RootChat(t *testing.T) {
})
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
coderHeaders,
)
t.Cleanup(cleanup)
@@ -204,7 +204,7 @@ func TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth(t *testing.T) {
})
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
coderHeaders,
)
t.Cleanup(cleanup)
@@ -257,6 +257,7 @@ func TestConnectAll_ForwardCoderHeaders_WithOAuth2(t *testing.T) {
ctx, logger,
[]database.MCPServerConfig{cfg},
[]database.MCPServerUserToken{token},
nil,
uuid.Nil, nil,
coderHeaders,
)
@@ -306,7 +307,7 @@ func TestConnectAll_ForwardCoderHeaders_WithCustomHeaders(t *testing.T) {
})
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
coderHeaders,
)
t.Cleanup(cleanup)
+87 -15
View File
@@ -72,6 +72,7 @@ func ConnectAll(
logger slog.Logger,
configs []database.MCPServerConfig,
tokens []database.MCPServerUserToken,
userHeaderValues []database.McpServerUserHeaderValue,
userID uuid.UUID,
oidcSrc UserOIDCTokenSource,
coderHeaders map[string]string,
@@ -85,6 +86,14 @@ func ConnectAll(
tokensByConfigID[tok.MCPServerConfigID] = tok
}
// Same indexing for the calling user's custom_headers values.
userHeaderValuesByConfigID := make(
map[uuid.UUID]database.McpServerUserHeaderValue, len(userHeaderValues),
)
for _, hv := range userHeaderValues {
userHeaderValuesByConfigID[hv.MCPServerConfigID] = hv
}
var (
mu sync.Mutex
clients []*client.Client
@@ -110,7 +119,7 @@ func ConnectAll(
eg.Go(func() error {
serverTools, mcpClient, connectErr := connectOne(
ctx, logger, cfg, tokensByConfigID, userID, oidcSrc, coderHeaders,
ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc, coderHeaders,
)
if connectErr != nil {
logger.Warn(ctx,
@@ -174,11 +183,12 @@ func connectOne(
logger slog.Logger,
cfg database.MCPServerConfig,
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
userID uuid.UUID,
oidcSrc UserOIDCTokenSource,
coderHeaders map[string]string,
) ([]fantasy.AgentTool, *client.Client, error) {
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userID, oidcSrc)
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc)
// When opted-in, merge Coder identity headers BEFORE the
// transport is created so any auth header already set above
@@ -285,24 +295,31 @@ func createTransport(
cfg database.MCPServerConfig,
headers map[string]string,
) (transport.Interface, error) {
httpClient := mcpHTTPClient()
// Each connection gets its own HTTP client with a dedicated
// transport so that httptest.Server.Close() (which calls
// CloseIdleConnections on http.DefaultTransport) does not
// disrupt unrelated connections during parallel tests.
var httpClient *http.Client
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
httpClient = &http.Client{Transport: dt.Clone()}
} else {
httpClient = &http.Client{}
}
switch cfg.Transport {
case "sse":
var opts []transport.ClientOption
opts = append(opts, transport.WithHeaders(headers))
if httpClient != nil {
opts = append(opts, transport.WithHTTPClient(httpClient))
}
return transport.NewSSE(cfg.Url, opts...)
return transport.NewSSE(
cfg.Url,
transport.WithHeaders(headers),
transport.WithHTTPClient(httpClient),
)
case "", "streamable_http":
// Default to streamable HTTP, the newer transport.
var opts []transport.StreamableHTTPCOption
opts = append(opts, transport.WithHTTPHeaders(headers))
if httpClient != nil {
opts = append(opts, transport.WithHTTPBasicClient(httpClient))
}
return transport.NewStreamableHTTP(cfg.Url, opts...)
return transport.NewStreamableHTTP(
cfg.Url,
transport.WithHTTPHeaders(headers),
transport.WithHTTPBasicClient(httpClient),
)
default:
return nil, xerrors.Errorf(
"unsupported transport %q", cfg.Transport,
@@ -317,6 +334,7 @@ func buildAuthHeaders(
logger slog.Logger,
cfg database.MCPServerConfig,
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
userID uuid.UUID,
oidcSrc UserOIDCTokenSource,
) map[string]string {
@@ -381,6 +399,43 @@ func buildAuthHeaders(
}
}
}
// Overlay user-supplied values for keys the admin marked as
// user-set. Validation guarantees these are disjoint from
// CustomHeaders, but the overlay is well-defined either way.
if len(cfg.CustomHeadersUserKeys) > 0 {
row, ok := userHeaderValuesByConfigID[cfg.ID]
if !ok {
// Normal state: this user has never saved values for
// this server. The MCP call will proceed without the
// user-set headers and likely fail at the remote end,
// which is the expected signal for the UI to prompt
// the user. Debug-level keeps this off the noise floor.
logger.Debug(ctx,
"no user header values for MCP server",
slog.F("server_slug", cfg.Slug),
)
break
}
var user map[string]string
if err := json.Unmarshal(
[]byte(row.HeaderValues), &user,
); err != nil {
logger.Warn(ctx,
"failed to parse user header values JSON",
slog.F("server_slug", cfg.Slug),
slog.Error(err),
)
break
}
for _, k := range cfg.CustomHeadersUserKeys {
// Case-insensitive lookup so a case-only admin rename
// does not silently drop the user's stored value.
v, has := mcpHeaderValueForKey(user, k)
if has && v != "" {
headers[k] = v
}
}
}
case "user_oidc":
// Forward the calling user's OIDC access token from
// user_links as Authorization: Bearer <token>. The token
@@ -422,6 +477,23 @@ func buildAuthHeaders(
return headers
}
// mcpHeaderValueForKey returns the stored value for key using a
// case-insensitive match. The stored user-header map preserves the
// admin's casing at write time, so a later case-only rename of a
// user-set key would otherwise orphan the stored value until the
// user re-saves it.
func mcpHeaderValueForKey(stored map[string]string, key string) (string, bool) {
if v, ok := stored[key]; ok {
return v, true
}
for k, v := range stored {
if strings.EqualFold(k, key) {
return v, true
}
}
return "", false
}
// isToolAllowed checks a tool name against the allow and deny
// lists. When the allow list is non-empty only tools in it are
// permitted and the deny list is ignored. When the allow list
+164 -25
View File
@@ -96,7 +96,7 @@ func TestConnectAll_DiscoverTools(t *testing.T) {
ts := newTestMCPServer(t, echoTool(), greetTool())
cfg := makeConfig("myserver", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
// Two tools should be discovered, namespaced with the server slug.
@@ -121,7 +121,7 @@ func TestConnectAll_CallTool(t *testing.T) {
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -147,7 +147,7 @@ func TestConnectAll_ToolAllowList(t *testing.T) {
// Only allow the "echo" tool.
cfg.ToolAllowList = []string{"echo"}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -165,7 +165,7 @@ func TestConnectAll_ToolDenyList(t *testing.T) {
// Deny the "greet" tool, so only "echo" remains.
cfg.ToolDenyList = []string{"greet"}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -179,7 +179,7 @@ func TestConnectAll_ConnectionFailure(t *testing.T) {
cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist")
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
assert.Empty(t, tools, "no tools should be returned for an unreachable server")
@@ -200,6 +200,7 @@ func TestConnectAll_MultipleServers(t *testing.T) {
ctx, logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -227,6 +228,7 @@ func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
logger,
[]database.MCPServerConfig{cfg},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -256,6 +258,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
makeConfig("srv2", ts2.URL),
},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -286,6 +289,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
makeConfig("aaa", other.URL),
},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -320,6 +324,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -385,6 +390,7 @@ func TestConnectAll_AuthHeaders(t *testing.T) {
ctx, logger,
[]database.MCPServerConfig{cfg},
[]database.MCPServerUserToken{token},
nil,
uuid.Nil, nil,
nil,
)
@@ -441,7 +447,7 @@ func TestConnectAll_DisabledServer(t *testing.T) {
cfg := makeConfig("disabled", ts.URL)
cfg.Enabled = false
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
assert.Empty(t, tools)
}
@@ -456,7 +462,7 @@ func TestConnectAll_CallToolInvalidInput(t *testing.T) {
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -481,7 +487,7 @@ func TestConnectAll_ToolInfoParameters(t *testing.T) {
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -523,7 +529,7 @@ func TestConnectAll_NilRequiredBecomesEmptySlice(t *testing.T) {
ts := newTestMCPServer(t, noRequiredTool)
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -575,6 +581,7 @@ func TestConnectAll_APIKeyAuth(t *testing.T) {
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -633,6 +640,7 @@ func TestConnectAll_CustomHeadersAuth(t *testing.T) {
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -671,6 +679,7 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -682,6 +691,134 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
assert.Equal(t, "badjson__echo", tools[0].Info().Name)
}
// TestConnectAll_CustomHeadersUserKeysOverlay verifies that
// custom_headers auth overlays per-user values onto the admin-set
// headers based on cfg.CustomHeadersUserKeys.
func TestConnectAll_CustomHeadersUserKeysOverlay(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
var (
mu sync.Mutex
seenOrg []string
seenUser []string
seenWS []string
)
srv := mcpserver.NewMCPServer("overlay-server", "1.0.0")
srv.AddTools(mcpserver.ServerTool{
Tool: mcp.NewTool("check",
mcp.WithDescription("Returns the auth headers"),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
mu.Lock()
seenOrg = append(seenOrg, req.Header.Get("X-Org-ID"))
seenUser = append(seenUser, req.Header.Get("X-User-Token"))
seenWS = append(seenWS, req.Header.Get("X-Workspace"))
mu.Unlock()
return mcp.NewToolResultText("ok"), nil
},
})
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
cfg := makeConfig("overlay", ts.URL)
cfg.AuthType = "custom_headers"
cfg.CustomHeaders = `{"X-Org-ID":"acme"}`
cfg.CustomHeadersUserKeys = []string{"X-User-Token", "X-Workspace"}
userHeaderValues := []database.McpServerUserHeaderValue{{
MCPServerConfigID: cfg.ID,
UserID: uuid.New(),
HeaderValues: `{"X-User-Token":"jwt-abc","X-Workspace":"main"}`,
}}
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
userHeaderValues,
uuid.Nil, nil,
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-overlay",
Name: "overlay__check",
Input: "{}",
})
require.NoError(t, err)
assert.False(t, resp.IsError)
mu.Lock()
defer mu.Unlock()
require.NotEmpty(t, seenOrg)
assert.Equal(t, "acme", seenOrg[len(seenOrg)-1], "admin header preserved")
assert.Equal(t, "jwt-abc", seenUser[len(seenUser)-1], "user-set header overlays")
assert.Equal(t, "main", seenWS[len(seenWS)-1], "user-set header overlays")
}
// TestConnectAll_CustomHeadersUserKeysMissingRow verifies that when
// CustomHeadersUserKeys is non-empty but no user header values row is
// present, admin headers still go out and user keys are simply absent.
func TestConnectAll_CustomHeadersUserKeysMissingRow(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
var (
mu sync.Mutex
seenOrg []string
seenUser []string
)
srv := mcpserver.NewMCPServer("missing-server", "1.0.0")
srv.AddTools(mcpserver.ServerTool{
Tool: mcp.NewTool("check",
mcp.WithDescription("Returns the auth headers"),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
mu.Lock()
seenOrg = append(seenOrg, req.Header.Get("X-Org-ID"))
seenUser = append(seenUser, req.Header.Get("X-User-Token"))
mu.Unlock()
return mcp.NewToolResultText("ok"), nil
},
})
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
cfg := makeConfig("missing", ts.URL)
cfg.AuthType = "custom_headers"
cfg.CustomHeaders = `{"X-Org-ID":"acme"}`
cfg.CustomHeadersUserKeys = []string{"X-User-Token"}
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
nil, // no userHeaderValues at all
uuid.Nil, nil,
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-missing",
Name: "missing__check",
Input: "{}",
})
require.NoError(t, err)
assert.False(t, resp.IsError)
mu.Lock()
defer mu.Unlock()
require.NotEmpty(t, seenOrg)
assert.Equal(t, "acme", seenOrg[len(seenOrg)-1])
assert.Equal(t, "", seenUser[len(seenUser)-1])
}
// staticOIDCSource implements mcpclient.UserOIDCTokenSource for tests
// without requiring a real OIDC provider or database round-trip.
type staticOIDCSource struct {
@@ -730,7 +867,7 @@ func TestConnectAll_UserOIDCAuth(t *testing.T) {
src := staticOIDCSource{token: "fake-oidc-token"}
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil,
userID, src, nil,
)
t.Cleanup(cleanup)
@@ -789,7 +926,7 @@ func TestConnectAll_UserOIDCAuth_NoLink(t *testing.T) {
src := staticOIDCSource{token: "", err: nil}
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil,
uuid.New(), src, nil,
)
t.Cleanup(cleanup)
@@ -825,7 +962,7 @@ func TestConnectAll_UserOIDCAuth_NilSource(t *testing.T) {
cfg.AuthType = "user_oidc"
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil,
uuid.New(), nil, nil,
)
t.Cleanup(cleanup)
@@ -854,6 +991,7 @@ func TestConnectAll_ParallelConnections(t *testing.T) {
ctx, logger,
[]database.MCPServerConfig{cfg1, cfg2, cfg3},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -916,7 +1054,7 @@ func TestConnectAll_ExpiredToken(t *testing.T) {
Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
// The server accepts any auth, so the tool is still discovered
@@ -949,7 +1087,7 @@ func TestConnectAll_EmptyAccessToken(t *testing.T) {
TokenType: "Bearer",
}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
// Tool is still discovered (server doesn't require auth), but
@@ -979,7 +1117,7 @@ func TestConnectAll_MCPToolIdentifier(t *testing.T) {
Enabled: true,
}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1025,6 +1163,7 @@ func TestConnectAll_MCPToolIdentifier_MultipleServers(t *testing.T) {
ctx, logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
nil,
uuid.Nil, nil,
nil,
)
@@ -1083,7 +1222,7 @@ func TestConnectAll_EmbeddedResourceText(t *testing.T) {
t.Cleanup(ts.Close)
cfg := makeConfig("embed-txt", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1150,7 +1289,7 @@ func TestConnectAll_EmbeddedResourceBlob(t *testing.T) {
t.Cleanup(ts.Close)
cfg := makeConfig("embed-blob", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1230,7 +1369,7 @@ func TestConnectAll_ResourceLink(t *testing.T) {
t.Cleanup(ts.Close)
cfg := makeConfig("res-link", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1274,7 +1413,7 @@ func TestConnectAll_CallToolError(t *testing.T) {
t.Cleanup(ts.Close)
cfg := makeConfig("err-srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1298,7 +1437,7 @@ func TestModelIntent_Info_WrapsSchema(t *testing.T) {
cfg := makeConfig("intent-srv", ts.URL)
cfg.ModelIntent = true
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1334,7 +1473,7 @@ func TestModelIntent_Info_NoWrapWhenDisabled(t *testing.T) {
cfg := makeConfig("no-intent", ts.URL)
cfg.ModelIntent = false
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1357,7 +1496,7 @@ func TestModelIntent_Run_UnwrapsProperties(t *testing.T) {
cfg := makeConfig("unwrap-srv", ts.URL)
cfg.ModelIntent = true
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1382,7 +1521,7 @@ func TestModelIntent_Run_UnwrapsFlat(t *testing.T) {
cfg := makeConfig("flat-srv", ts.URL)
cfg.ModelIntent = true
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1407,7 +1546,7 @@ func TestModelIntent_Run_PassthroughWhenDisabled(t *testing.T) {
cfg := makeConfig("pass-srv", ts.URL)
cfg.ModelIntent = false
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
@@ -1432,7 +1571,7 @@ func TestModelIntent_Run_FallbackOnBadJSON(t *testing.T) {
cfg := makeConfig("bad-srv", ts.URL)
cfg.ModelIntent = true
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
-25
View File
@@ -1,25 +0,0 @@
package mcpclient
import (
"net/http"
"testing"
)
// mcpHTTPClient returns an isolated *http.Client when running
// inside tests, or nil for production. During tests,
// httptest.Server.Close() calls
// http.DefaultTransport.CloseIdleConnections(), which disrupts
// any MCP client sharing that transport. When DefaultTransport
// is a *http.Transport it is cloned; otherwise a minimal
// transport with ProxyFromEnvironment is created as a fallback.
func mcpHTTPClient() *http.Client {
if !testing.Testing() {
return nil
}
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
return &http.Client{Transport: dt.Clone()}
}
return &http.Client{Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
}}
}