From d2f9ad783e6748c987c6e7a8faa7d596a3fa438d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 1 Jun 2026 15:02:34 +0000 Subject: [PATCH] feat(coderd/x/chatd): overlay user-set custom_headers at runtime Threads the per-user custom_headers values stored on mcp_server_user_header_values through the chatd MCP client so users who provided a value for an admin-marked CustomHeadersUserKey see it mixed into the outgoing request alongside the admin-static headers. Changes: - mcpclient.ConnectAll grows a fourth indexed input, []database.McpServerUserHeaderValue, which buildAuthHeaders consults inside the custom_headers branch to overlay per-user values on top of admin static headers, scoped to cfg.CustomHeadersUserKeys. - chatd loads the user's stored header values via GetMCPServerUserHeaderValuesByUserID alongside the existing GetMCPServerUserTokensByUserID call and threads them into ConnectAll. A missing row is non-fatal: admin headers still ship, user-keyed headers are simply absent and a warning is logged. - mcpclient.go inlines its own DefaultTransport clone for test isolation, replacing the standalone helper in mcphttpclient.go, which is removed. Stack: 4/6 (chatd runtime overlay) --- coderd/x/chatd/chatd.go | 23 ++- .../x/chatd/mcpclient/coder_headers_test.go | 11 +- coderd/x/chatd/mcpclient/mcpclient.go | 102 ++++++++-- coderd/x/chatd/mcpclient/mcpclient_test.go | 189 +++++++++++++++--- coderd/x/chatd/mcpclient/mcphttpclient.go | 25 --- 5 files changed, 277 insertions(+), 73 deletions(-) delete mode 100644 coderd/x/chatd/mcpclient/mcphttpclient.go diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 013567e3e0..ec17726e76 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -7130,8 +7130,9 @@ func (p *Server) runChat( // resolution. These queries have no dependencies on each other and all // hit different tables. var ( - mcpConfigs []database.MCPServerConfig - mcpTokens []database.MCPServerUserToken + mcpConfigs []database.MCPServerConfig + mcpTokens []database.MCPServerUserToken + mcpHeaderValues []database.McpServerUserHeaderValue ) var g errgroup.Group g.Go(func() error { @@ -7179,6 +7180,22 @@ func (p *Server) runChat( } return nil }) + g.Go(func() error { + var err error + // If header-values loading fails, ConnectAll proceeds + // without user values; custom_headers servers that + // require user-set keys will be missing those headers. + mcpHeaderValues, err = p.db.GetMCPServerUserHeaderValuesByUserID( + ctx, chat.OwnerID, + ) + if err != nil { + logger.Warn(ctx, + "failed to load MCP user header values", + slog.Error(err), + ) + } + return nil + }) } if err := g.Wait(); err != nil { return result, err @@ -7493,7 +7510,7 @@ func (p *Server) runChat( // Refresh expired OAuth2 tokens before connecting. mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens) mcpTools, mcpCleanup = mcpclient.ConnectAll( - ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource, + ctx, logger, mcpConnectConfigs, mcpTokens, mcpHeaderValues, chat.OwnerID, p.oidcTokenSource, chatprovider.CoderHeaders(chat), ) return nil diff --git a/coderd/x/chatd/mcpclient/coder_headers_test.go b/coderd/x/chatd/mcpclient/coder_headers_test.go index f90a031d5a..311e814c76 100644 --- a/coderd/x/chatd/mcpclient/coder_headers_test.go +++ b/coderd/x/chatd/mcpclient/coder_headers_test.go @@ -65,7 +65,7 @@ func TestConnectAll_ForwardCoderHeaders_DefaultOff(t *testing.T) { } tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, coderHeaders, ) t.Cleanup(cleanup) @@ -115,7 +115,7 @@ func TestConnectAll_ForwardCoderHeaders_Enabled(t *testing.T) { }) tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, coderHeaders, ) t.Cleanup(cleanup) @@ -158,7 +158,7 @@ func TestConnectAll_ForwardCoderHeaders_RootChat(t *testing.T) { }) tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, coderHeaders, ) t.Cleanup(cleanup) @@ -204,7 +204,7 @@ func TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth(t *testing.T) { }) tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, coderHeaders, ) t.Cleanup(cleanup) @@ -257,6 +257,7 @@ func TestConnectAll_ForwardCoderHeaders_WithOAuth2(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, + nil, uuid.Nil, nil, coderHeaders, ) @@ -306,7 +307,7 @@ func TestConnectAll_ForwardCoderHeaders_WithCustomHeaders(t *testing.T) { }) tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, coderHeaders, ) t.Cleanup(cleanup) diff --git a/coderd/x/chatd/mcpclient/mcpclient.go b/coderd/x/chatd/mcpclient/mcpclient.go index cb7e0322c2..422d93d30e 100644 --- a/coderd/x/chatd/mcpclient/mcpclient.go +++ b/coderd/x/chatd/mcpclient/mcpclient.go @@ -72,6 +72,7 @@ func ConnectAll( logger slog.Logger, configs []database.MCPServerConfig, tokens []database.MCPServerUserToken, + userHeaderValues []database.McpServerUserHeaderValue, userID uuid.UUID, oidcSrc UserOIDCTokenSource, coderHeaders map[string]string, @@ -85,6 +86,14 @@ func ConnectAll( tokensByConfigID[tok.MCPServerConfigID] = tok } + // Same indexing for the calling user's custom_headers values. + userHeaderValuesByConfigID := make( + map[uuid.UUID]database.McpServerUserHeaderValue, len(userHeaderValues), + ) + for _, hv := range userHeaderValues { + userHeaderValuesByConfigID[hv.MCPServerConfigID] = hv + } + var ( mu sync.Mutex clients []*client.Client @@ -110,7 +119,7 @@ func ConnectAll( eg.Go(func() error { serverTools, mcpClient, connectErr := connectOne( - ctx, logger, cfg, tokensByConfigID, userID, oidcSrc, coderHeaders, + ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc, coderHeaders, ) if connectErr != nil { logger.Warn(ctx, @@ -174,11 +183,12 @@ func connectOne( logger slog.Logger, cfg database.MCPServerConfig, tokensByConfigID map[uuid.UUID]database.MCPServerUserToken, + userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue, userID uuid.UUID, oidcSrc UserOIDCTokenSource, coderHeaders map[string]string, ) ([]fantasy.AgentTool, *client.Client, error) { - headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userID, oidcSrc) + headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc) // When opted-in, merge Coder identity headers BEFORE the // transport is created so any auth header already set above @@ -285,24 +295,31 @@ func createTransport( cfg database.MCPServerConfig, headers map[string]string, ) (transport.Interface, error) { - httpClient := mcpHTTPClient() + // Each connection gets its own HTTP client with a dedicated + // transport so that httptest.Server.Close() (which calls + // CloseIdleConnections on http.DefaultTransport) does not + // disrupt unrelated connections during parallel tests. + var httpClient *http.Client + if dt, ok := http.DefaultTransport.(*http.Transport); ok { + httpClient = &http.Client{Transport: dt.Clone()} + } else { + httpClient = &http.Client{} + } switch cfg.Transport { case "sse": - var opts []transport.ClientOption - opts = append(opts, transport.WithHeaders(headers)) - if httpClient != nil { - opts = append(opts, transport.WithHTTPClient(httpClient)) - } - return transport.NewSSE(cfg.Url, opts...) + return transport.NewSSE( + cfg.Url, + transport.WithHeaders(headers), + transport.WithHTTPClient(httpClient), + ) case "", "streamable_http": // Default to streamable HTTP, the newer transport. - var opts []transport.StreamableHTTPCOption - opts = append(opts, transport.WithHTTPHeaders(headers)) - if httpClient != nil { - opts = append(opts, transport.WithHTTPBasicClient(httpClient)) - } - return transport.NewStreamableHTTP(cfg.Url, opts...) + return transport.NewStreamableHTTP( + cfg.Url, + transport.WithHTTPHeaders(headers), + transport.WithHTTPBasicClient(httpClient), + ) default: return nil, xerrors.Errorf( "unsupported transport %q", cfg.Transport, @@ -317,6 +334,7 @@ func buildAuthHeaders( logger slog.Logger, cfg database.MCPServerConfig, tokensByConfigID map[uuid.UUID]database.MCPServerUserToken, + userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue, userID uuid.UUID, oidcSrc UserOIDCTokenSource, ) map[string]string { @@ -381,6 +399,43 @@ func buildAuthHeaders( } } } + // Overlay user-supplied values for keys the admin marked as + // user-set. Validation guarantees these are disjoint from + // CustomHeaders, but the overlay is well-defined either way. + if len(cfg.CustomHeadersUserKeys) > 0 { + row, ok := userHeaderValuesByConfigID[cfg.ID] + if !ok { + // Normal state: this user has never saved values for + // this server. The MCP call will proceed without the + // user-set headers and likely fail at the remote end, + // which is the expected signal for the UI to prompt + // the user. Debug-level keeps this off the noise floor. + logger.Debug(ctx, + "no user header values for MCP server", + slog.F("server_slug", cfg.Slug), + ) + break + } + var user map[string]string + if err := json.Unmarshal( + []byte(row.HeaderValues), &user, + ); err != nil { + logger.Warn(ctx, + "failed to parse user header values JSON", + slog.F("server_slug", cfg.Slug), + slog.Error(err), + ) + break + } + for _, k := range cfg.CustomHeadersUserKeys { + // Case-insensitive lookup so a case-only admin rename + // does not silently drop the user's stored value. + v, has := mcpHeaderValueForKey(user, k) + if has && v != "" { + headers[k] = v + } + } + } case "user_oidc": // Forward the calling user's OIDC access token from // user_links as Authorization: Bearer . The token @@ -422,6 +477,23 @@ func buildAuthHeaders( return headers } +// mcpHeaderValueForKey returns the stored value for key using a +// case-insensitive match. The stored user-header map preserves the +// admin's casing at write time, so a later case-only rename of a +// user-set key would otherwise orphan the stored value until the +// user re-saves it. +func mcpHeaderValueForKey(stored map[string]string, key string) (string, bool) { + if v, ok := stored[key]; ok { + return v, true + } + for k, v := range stored { + if strings.EqualFold(k, key) { + return v, true + } + } + return "", false +} + // isToolAllowed checks a tool name against the allow and deny // lists. When the allow list is non-empty only tools in it are // permitted and the deny list is ignored. When the allow list diff --git a/coderd/x/chatd/mcpclient/mcpclient_test.go b/coderd/x/chatd/mcpclient/mcpclient_test.go index d91788fd2f..c873b3776a 100644 --- a/coderd/x/chatd/mcpclient/mcpclient_test.go +++ b/coderd/x/chatd/mcpclient/mcpclient_test.go @@ -96,7 +96,7 @@ func TestConnectAll_DiscoverTools(t *testing.T) { ts := newTestMCPServer(t, echoTool(), greetTool()) cfg := makeConfig("myserver", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) // Two tools should be discovered, namespaced with the server slug. @@ -121,7 +121,7 @@ func TestConnectAll_CallTool(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -147,7 +147,7 @@ func TestConnectAll_ToolAllowList(t *testing.T) { // Only allow the "echo" tool. cfg.ToolAllowList = []string{"echo"} - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -165,7 +165,7 @@ func TestConnectAll_ToolDenyList(t *testing.T) { // Deny the "greet" tool, so only "echo" remains. cfg.ToolDenyList = []string{"greet"} - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -179,7 +179,7 @@ func TestConnectAll_ConnectionFailure(t *testing.T) { cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist") - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) assert.Empty(t, tools, "no tools should be returned for an unreachable server") @@ -200,6 +200,7 @@ func TestConnectAll_MultipleServers(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg1, cfg2}, nil, + nil, uuid.Nil, nil, nil, ) @@ -227,6 +228,7 @@ func TestConnectAll_NoToolsAfterFiltering(t *testing.T) { logger, []database.MCPServerConfig{cfg}, nil, + nil, uuid.Nil, nil, nil, ) @@ -256,6 +258,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { makeConfig("srv2", ts2.URL), }, nil, + nil, uuid.Nil, nil, nil, ) @@ -286,6 +289,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { makeConfig("aaa", other.URL), }, nil, + nil, uuid.Nil, nil, nil, ) @@ -320,6 +324,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { logger, []database.MCPServerConfig{cfg1, cfg2}, nil, + nil, uuid.Nil, nil, nil, ) @@ -385,6 +390,7 @@ func TestConnectAll_AuthHeaders(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, + nil, uuid.Nil, nil, nil, ) @@ -441,7 +447,7 @@ func TestConnectAll_DisabledServer(t *testing.T) { cfg := makeConfig("disabled", ts.URL) cfg.Enabled = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) assert.Empty(t, tools) } @@ -456,7 +462,7 @@ func TestConnectAll_CallToolInvalidInput(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -481,7 +487,7 @@ func TestConnectAll_ToolInfoParameters(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -523,7 +529,7 @@ func TestConnectAll_NilRequiredBecomesEmptySlice(t *testing.T) { ts := newTestMCPServer(t, noRequiredTool) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -575,6 +581,7 @@ func TestConnectAll_APIKeyAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, + nil, uuid.Nil, nil, nil, ) @@ -633,6 +640,7 @@ func TestConnectAll_CustomHeadersAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, + nil, uuid.Nil, nil, nil, ) @@ -671,6 +679,7 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, + nil, uuid.Nil, nil, nil, ) @@ -682,6 +691,134 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) { assert.Equal(t, "badjson__echo", tools[0].Info().Name) } +// TestConnectAll_CustomHeadersUserKeysOverlay verifies that +// custom_headers auth overlays per-user values onto the admin-set +// headers based on cfg.CustomHeadersUserKeys. +func TestConnectAll_CustomHeadersUserKeysOverlay(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + var ( + mu sync.Mutex + seenOrg []string + seenUser []string + seenWS []string + ) + + srv := mcpserver.NewMCPServer("overlay-server", "1.0.0") + srv.AddTools(mcpserver.ServerTool{ + Tool: mcp.NewTool("check", + mcp.WithDescription("Returns the auth headers"), + ), + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mu.Lock() + seenOrg = append(seenOrg, req.Header.Get("X-Org-ID")) + seenUser = append(seenUser, req.Header.Get("X-User-Token")) + seenWS = append(seenWS, req.Header.Get("X-Workspace")) + mu.Unlock() + return mcp.NewToolResultText("ok"), nil + }, + }) + httpSrv := mcpserver.NewStreamableHTTPServer(srv) + ts := httptest.NewServer(httpSrv) + t.Cleanup(ts.Close) + + cfg := makeConfig("overlay", ts.URL) + cfg.AuthType = "custom_headers" + cfg.CustomHeaders = `{"X-Org-ID":"acme"}` + cfg.CustomHeadersUserKeys = []string{"X-User-Token", "X-Workspace"} + + userHeaderValues := []database.McpServerUserHeaderValue{{ + MCPServerConfigID: cfg.ID, + UserID: uuid.New(), + HeaderValues: `{"X-User-Token":"jwt-abc","X-Workspace":"main"}`, + }} + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, + userHeaderValues, + uuid.Nil, nil, + nil, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + resp, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-overlay", + Name: "overlay__check", + Input: "{}", + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, seenOrg) + assert.Equal(t, "acme", seenOrg[len(seenOrg)-1], "admin header preserved") + assert.Equal(t, "jwt-abc", seenUser[len(seenUser)-1], "user-set header overlays") + assert.Equal(t, "main", seenWS[len(seenWS)-1], "user-set header overlays") +} + +// TestConnectAll_CustomHeadersUserKeysMissingRow verifies that when +// CustomHeadersUserKeys is non-empty but no user header values row is +// present, admin headers still go out and user keys are simply absent. +func TestConnectAll_CustomHeadersUserKeysMissingRow(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + var ( + mu sync.Mutex + seenOrg []string + seenUser []string + ) + srv := mcpserver.NewMCPServer("missing-server", "1.0.0") + srv.AddTools(mcpserver.ServerTool{ + Tool: mcp.NewTool("check", + mcp.WithDescription("Returns the auth headers"), + ), + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mu.Lock() + seenOrg = append(seenOrg, req.Header.Get("X-Org-ID")) + seenUser = append(seenUser, req.Header.Get("X-User-Token")) + mu.Unlock() + return mcp.NewToolResultText("ok"), nil + }, + }) + httpSrv := mcpserver.NewStreamableHTTPServer(srv) + ts := httptest.NewServer(httpSrv) + t.Cleanup(ts.Close) + + cfg := makeConfig("missing", ts.URL) + cfg.AuthType = "custom_headers" + cfg.CustomHeaders = `{"X-Org-ID":"acme"}` + cfg.CustomHeadersUserKeys = []string{"X-User-Token"} + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, + nil, // no userHeaderValues at all + uuid.Nil, nil, + nil, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + resp, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-missing", + Name: "missing__check", + Input: "{}", + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, seenOrg) + assert.Equal(t, "acme", seenOrg[len(seenOrg)-1]) + assert.Equal(t, "", seenUser[len(seenUser)-1]) +} + // staticOIDCSource implements mcpclient.UserOIDCTokenSource for tests // without requiring a real OIDC provider or database round-trip. type staticOIDCSource struct { @@ -730,7 +867,7 @@ func TestConnectAll_UserOIDCAuth(t *testing.T) { src := staticOIDCSource{token: "fake-oidc-token"} tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, userID, src, nil, ) t.Cleanup(cleanup) @@ -789,7 +926,7 @@ func TestConnectAll_UserOIDCAuth_NoLink(t *testing.T) { src := staticOIDCSource{token: "", err: nil} tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.New(), src, nil, ) t.Cleanup(cleanup) @@ -825,7 +962,7 @@ func TestConnectAll_UserOIDCAuth_NilSource(t *testing.T) { cfg.AuthType = "user_oidc" tools, cleanup := mcpclient.ConnectAll( - ctx, logger, []database.MCPServerConfig{cfg}, nil, + ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.New(), nil, nil, ) t.Cleanup(cleanup) @@ -854,6 +991,7 @@ func TestConnectAll_ParallelConnections(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg1, cfg2, cfg3}, nil, + nil, uuid.Nil, nil, nil, ) @@ -916,7 +1054,7 @@ func TestConnectAll_ExpiredToken(t *testing.T) { Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) // The server accepts any auth, so the tool is still discovered @@ -949,7 +1087,7 @@ func TestConnectAll_EmptyAccessToken(t *testing.T) { TokenType: "Bearer", } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) // Tool is still discovered (server doesn't require auth), but @@ -979,7 +1117,7 @@ func TestConnectAll_MCPToolIdentifier(t *testing.T) { Enabled: true, } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1025,6 +1163,7 @@ func TestConnectAll_MCPToolIdentifier_MultipleServers(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg1, cfg2}, nil, + nil, uuid.Nil, nil, nil, ) @@ -1083,7 +1222,7 @@ func TestConnectAll_EmbeddedResourceText(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("embed-txt", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1150,7 +1289,7 @@ func TestConnectAll_EmbeddedResourceBlob(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("embed-blob", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1230,7 +1369,7 @@ func TestConnectAll_ResourceLink(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("res-link", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1274,7 +1413,7 @@ func TestConnectAll_CallToolError(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("err-srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1298,7 +1437,7 @@ func TestModelIntent_Info_WrapsSchema(t *testing.T) { cfg := makeConfig("intent-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1334,7 +1473,7 @@ func TestModelIntent_Info_NoWrapWhenDisabled(t *testing.T) { cfg := makeConfig("no-intent", ts.URL) cfg.ModelIntent = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1357,7 +1496,7 @@ func TestModelIntent_Run_UnwrapsProperties(t *testing.T) { cfg := makeConfig("unwrap-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1382,7 +1521,7 @@ func TestModelIntent_Run_UnwrapsFlat(t *testing.T) { cfg := makeConfig("flat-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1407,7 +1546,7 @@ func TestModelIntent_Run_PassthroughWhenDisabled(t *testing.T) { cfg := makeConfig("pass-srv", ts.URL) cfg.ModelIntent = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1432,7 +1571,7 @@ func TestModelIntent_Run_FallbackOnBadJSON(t *testing.T) { cfg := makeConfig("bad-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) diff --git a/coderd/x/chatd/mcpclient/mcphttpclient.go b/coderd/x/chatd/mcpclient/mcphttpclient.go deleted file mode 100644 index c34ff59262..0000000000 --- a/coderd/x/chatd/mcpclient/mcphttpclient.go +++ /dev/null @@ -1,25 +0,0 @@ -package mcpclient - -import ( - "net/http" - "testing" -) - -// mcpHTTPClient returns an isolated *http.Client when running -// inside tests, or nil for production. During tests, -// httptest.Server.Close() calls -// http.DefaultTransport.CloseIdleConnections(), which disrupts -// any MCP client sharing that transport. When DefaultTransport -// is a *http.Transport it is cloned; otherwise a minimal -// transport with ProxyFromEnvironment is created as a fallback. -func mcpHTTPClient() *http.Client { - if !testing.Testing() { - return nil - } - if dt, ok := http.DefaultTransport.(*http.Transport); ok { - return &http.Client{Transport: dt.Clone()} - } - return &http.Client{Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - }} -}