mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(coderd/x/chatd): overlay user-set custom_headers at runtime
Threads the per-user custom_headers values stored on mcp_server_user_header_values through the chatd MCP client so users who provided a value for an admin-marked CustomHeadersUserKey see it mixed into the outgoing request alongside the admin-static headers. Changes: - mcpclient.ConnectAll grows a fourth indexed input, []database.McpServerUserHeaderValue, which buildAuthHeaders consults inside the custom_headers branch to overlay per-user values on top of admin static headers, scoped to cfg.CustomHeadersUserKeys. - chatd loads the user's stored header values via GetMCPServerUserHeaderValuesByUserID alongside the existing GetMCPServerUserTokensByUserID call and threads them into ConnectAll. A missing row is non-fatal: admin headers still ship, user-keyed headers are simply absent and a warning is logged. - mcpclient.go inlines its own DefaultTransport clone for test isolation, replacing the standalone helper in mcphttpclient.go, which is removed. Stack: 4/6 (chatd runtime overlay)
This commit is contained in:
+20
-3
@@ -7130,8 +7130,9 @@ func (p *Server) runChat(
|
||||
// resolution. These queries have no dependencies on each other and all
|
||||
// hit different tables.
|
||||
var (
|
||||
mcpConfigs []database.MCPServerConfig
|
||||
mcpTokens []database.MCPServerUserToken
|
||||
mcpConfigs []database.MCPServerConfig
|
||||
mcpTokens []database.MCPServerUserToken
|
||||
mcpHeaderValues []database.McpServerUserHeaderValue
|
||||
)
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
@@ -7179,6 +7180,22 @@ func (p *Server) runChat(
|
||||
}
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
// If header-values loading fails, ConnectAll proceeds
|
||||
// without user values; custom_headers servers that
|
||||
// require user-set keys will be missing those headers.
|
||||
mcpHeaderValues, err = p.db.GetMCPServerUserHeaderValuesByUserID(
|
||||
ctx, chat.OwnerID,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx,
|
||||
"failed to load MCP user header values",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return result, err
|
||||
@@ -7493,7 +7510,7 @@ func (p *Server) runChat(
|
||||
// Refresh expired OAuth2 tokens before connecting.
|
||||
mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens)
|
||||
mcpTools, mcpCleanup = mcpclient.ConnectAll(
|
||||
ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource,
|
||||
ctx, logger, mcpConnectConfigs, mcpTokens, mcpHeaderValues, chat.OwnerID, p.oidcTokenSource,
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
return nil
|
||||
|
||||
@@ -65,7 +65,7 @@ func TestConnectAll_ForwardCoderHeaders_DefaultOff(t *testing.T) {
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
||||
coderHeaders,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -115,7 +115,7 @@ func TestConnectAll_ForwardCoderHeaders_Enabled(t *testing.T) {
|
||||
})
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
||||
coderHeaders,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -158,7 +158,7 @@ func TestConnectAll_ForwardCoderHeaders_RootChat(t *testing.T) {
|
||||
})
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
||||
coderHeaders,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -204,7 +204,7 @@ func TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth(t *testing.T) {
|
||||
})
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
||||
coderHeaders,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -257,6 +257,7 @@ func TestConnectAll_ForwardCoderHeaders_WithOAuth2(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
[]database.MCPServerUserToken{token},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
coderHeaders,
|
||||
)
|
||||
@@ -306,7 +307,7 @@ func TestConnectAll_ForwardCoderHeaders_WithCustomHeaders(t *testing.T) {
|
||||
})
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
||||
coderHeaders,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -72,6 +72,7 @@ func ConnectAll(
|
||||
logger slog.Logger,
|
||||
configs []database.MCPServerConfig,
|
||||
tokens []database.MCPServerUserToken,
|
||||
userHeaderValues []database.McpServerUserHeaderValue,
|
||||
userID uuid.UUID,
|
||||
oidcSrc UserOIDCTokenSource,
|
||||
coderHeaders map[string]string,
|
||||
@@ -85,6 +86,14 @@ func ConnectAll(
|
||||
tokensByConfigID[tok.MCPServerConfigID] = tok
|
||||
}
|
||||
|
||||
// Same indexing for the calling user's custom_headers values.
|
||||
userHeaderValuesByConfigID := make(
|
||||
map[uuid.UUID]database.McpServerUserHeaderValue, len(userHeaderValues),
|
||||
)
|
||||
for _, hv := range userHeaderValues {
|
||||
userHeaderValuesByConfigID[hv.MCPServerConfigID] = hv
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
clients []*client.Client
|
||||
@@ -110,7 +119,7 @@ func ConnectAll(
|
||||
|
||||
eg.Go(func() error {
|
||||
serverTools, mcpClient, connectErr := connectOne(
|
||||
ctx, logger, cfg, tokensByConfigID, userID, oidcSrc, coderHeaders,
|
||||
ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc, coderHeaders,
|
||||
)
|
||||
if connectErr != nil {
|
||||
logger.Warn(ctx,
|
||||
@@ -174,11 +183,12 @@ func connectOne(
|
||||
logger slog.Logger,
|
||||
cfg database.MCPServerConfig,
|
||||
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
||||
userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
|
||||
userID uuid.UUID,
|
||||
oidcSrc UserOIDCTokenSource,
|
||||
coderHeaders map[string]string,
|
||||
) ([]fantasy.AgentTool, *client.Client, error) {
|
||||
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userID, oidcSrc)
|
||||
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userHeaderValuesByConfigID, userID, oidcSrc)
|
||||
|
||||
// When opted-in, merge Coder identity headers BEFORE the
|
||||
// transport is created so any auth header already set above
|
||||
@@ -285,24 +295,31 @@ func createTransport(
|
||||
cfg database.MCPServerConfig,
|
||||
headers map[string]string,
|
||||
) (transport.Interface, error) {
|
||||
httpClient := mcpHTTPClient()
|
||||
// Each connection gets its own HTTP client with a dedicated
|
||||
// transport so that httptest.Server.Close() (which calls
|
||||
// CloseIdleConnections on http.DefaultTransport) does not
|
||||
// disrupt unrelated connections during parallel tests.
|
||||
var httpClient *http.Client
|
||||
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
httpClient = &http.Client{Transport: dt.Clone()}
|
||||
} else {
|
||||
httpClient = &http.Client{}
|
||||
}
|
||||
|
||||
switch cfg.Transport {
|
||||
case "sse":
|
||||
var opts []transport.ClientOption
|
||||
opts = append(opts, transport.WithHeaders(headers))
|
||||
if httpClient != nil {
|
||||
opts = append(opts, transport.WithHTTPClient(httpClient))
|
||||
}
|
||||
return transport.NewSSE(cfg.Url, opts...)
|
||||
return transport.NewSSE(
|
||||
cfg.Url,
|
||||
transport.WithHeaders(headers),
|
||||
transport.WithHTTPClient(httpClient),
|
||||
)
|
||||
case "", "streamable_http":
|
||||
// Default to streamable HTTP, the newer transport.
|
||||
var opts []transport.StreamableHTTPCOption
|
||||
opts = append(opts, transport.WithHTTPHeaders(headers))
|
||||
if httpClient != nil {
|
||||
opts = append(opts, transport.WithHTTPBasicClient(httpClient))
|
||||
}
|
||||
return transport.NewStreamableHTTP(cfg.Url, opts...)
|
||||
return transport.NewStreamableHTTP(
|
||||
cfg.Url,
|
||||
transport.WithHTTPHeaders(headers),
|
||||
transport.WithHTTPBasicClient(httpClient),
|
||||
)
|
||||
default:
|
||||
return nil, xerrors.Errorf(
|
||||
"unsupported transport %q", cfg.Transport,
|
||||
@@ -317,6 +334,7 @@ func buildAuthHeaders(
|
||||
logger slog.Logger,
|
||||
cfg database.MCPServerConfig,
|
||||
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
||||
userHeaderValuesByConfigID map[uuid.UUID]database.McpServerUserHeaderValue,
|
||||
userID uuid.UUID,
|
||||
oidcSrc UserOIDCTokenSource,
|
||||
) map[string]string {
|
||||
@@ -381,6 +399,43 @@ func buildAuthHeaders(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Overlay user-supplied values for keys the admin marked as
|
||||
// user-set. Validation guarantees these are disjoint from
|
||||
// CustomHeaders, but the overlay is well-defined either way.
|
||||
if len(cfg.CustomHeadersUserKeys) > 0 {
|
||||
row, ok := userHeaderValuesByConfigID[cfg.ID]
|
||||
if !ok {
|
||||
// Normal state: this user has never saved values for
|
||||
// this server. The MCP call will proceed without the
|
||||
// user-set headers and likely fail at the remote end,
|
||||
// which is the expected signal for the UI to prompt
|
||||
// the user. Debug-level keeps this off the noise floor.
|
||||
logger.Debug(ctx,
|
||||
"no user header values for MCP server",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
)
|
||||
break
|
||||
}
|
||||
var user map[string]string
|
||||
if err := json.Unmarshal(
|
||||
[]byte(row.HeaderValues), &user,
|
||||
); err != nil {
|
||||
logger.Warn(ctx,
|
||||
"failed to parse user header values JSON",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
slog.Error(err),
|
||||
)
|
||||
break
|
||||
}
|
||||
for _, k := range cfg.CustomHeadersUserKeys {
|
||||
// Case-insensitive lookup so a case-only admin rename
|
||||
// does not silently drop the user's stored value.
|
||||
v, has := mcpHeaderValueForKey(user, k)
|
||||
if has && v != "" {
|
||||
headers[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
case "user_oidc":
|
||||
// Forward the calling user's OIDC access token from
|
||||
// user_links as Authorization: Bearer <token>. The token
|
||||
@@ -422,6 +477,23 @@ func buildAuthHeaders(
|
||||
return headers
|
||||
}
|
||||
|
||||
// mcpHeaderValueForKey returns the stored value for key using a
|
||||
// case-insensitive match. The stored user-header map preserves the
|
||||
// admin's casing at write time, so a later case-only rename of a
|
||||
// user-set key would otherwise orphan the stored value until the
|
||||
// user re-saves it.
|
||||
func mcpHeaderValueForKey(stored map[string]string, key string) (string, bool) {
|
||||
if v, ok := stored[key]; ok {
|
||||
return v, true
|
||||
}
|
||||
for k, v := range stored {
|
||||
if strings.EqualFold(k, key) {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// isToolAllowed checks a tool name against the allow and deny
|
||||
// lists. When the allow list is non-empty only tools in it are
|
||||
// permitted and the deny list is ignored. When the allow list
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestConnectAll_DiscoverTools(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool(), greetTool())
|
||||
|
||||
cfg := makeConfig("myserver", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Two tools should be discovered, namespaced with the server slug.
|
||||
@@ -121,7 +121,7 @@ func TestConnectAll_CallTool(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -147,7 +147,7 @@ func TestConnectAll_ToolAllowList(t *testing.T) {
|
||||
// Only allow the "echo" tool.
|
||||
cfg.ToolAllowList = []string{"echo"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
@@ -165,7 +165,7 @@ func TestConnectAll_ToolDenyList(t *testing.T) {
|
||||
// Deny the "greet" tool, so only "echo" remains.
|
||||
cfg.ToolDenyList = []string{"greet"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
@@ -179,7 +179,7 @@ func TestConnectAll_ConnectionFailure(t *testing.T) {
|
||||
|
||||
cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist")
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
assert.Empty(t, tools, "no tools should be returned for an unreachable server")
|
||||
@@ -200,6 +200,7 @@ func TestConnectAll_MultipleServers(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -227,6 +228,7 @@ func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -256,6 +258,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
makeConfig("srv2", ts2.URL),
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -286,6 +289,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
makeConfig("aaa", other.URL),
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -320,6 +324,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -385,6 +390,7 @@ func TestConnectAll_AuthHeaders(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
[]database.MCPServerUserToken{token},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -441,7 +447,7 @@ func TestConnectAll_DisabledServer(t *testing.T) {
|
||||
cfg := makeConfig("disabled", ts.URL)
|
||||
cfg.Enabled = false
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
assert.Empty(t, tools)
|
||||
}
|
||||
@@ -456,7 +462,7 @@ func TestConnectAll_CallToolInvalidInput(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -481,7 +487,7 @@ func TestConnectAll_ToolInfoParameters(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -523,7 +529,7 @@ func TestConnectAll_NilRequiredBecomesEmptySlice(t *testing.T) {
|
||||
|
||||
ts := newTestMCPServer(t, noRequiredTool)
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -575,6 +581,7 @@ func TestConnectAll_APIKeyAuth(t *testing.T) {
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -633,6 +640,7 @@ func TestConnectAll_CustomHeadersAuth(t *testing.T) {
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -671,6 +679,7 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -682,6 +691,134 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
|
||||
assert.Equal(t, "badjson__echo", tools[0].Info().Name)
|
||||
}
|
||||
|
||||
// TestConnectAll_CustomHeadersUserKeysOverlay verifies that
|
||||
// custom_headers auth overlays per-user values onto the admin-set
|
||||
// headers based on cfg.CustomHeadersUserKeys.
|
||||
func TestConnectAll_CustomHeadersUserKeysOverlay(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seenOrg []string
|
||||
seenUser []string
|
||||
seenWS []string
|
||||
)
|
||||
|
||||
srv := mcpserver.NewMCPServer("overlay-server", "1.0.0")
|
||||
srv.AddTools(mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool("check",
|
||||
mcp.WithDescription("Returns the auth headers"),
|
||||
),
|
||||
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
mu.Lock()
|
||||
seenOrg = append(seenOrg, req.Header.Get("X-Org-ID"))
|
||||
seenUser = append(seenUser, req.Header.Get("X-User-Token"))
|
||||
seenWS = append(seenWS, req.Header.Get("X-Workspace"))
|
||||
mu.Unlock()
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
})
|
||||
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
|
||||
ts := httptest.NewServer(httpSrv)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("overlay", ts.URL)
|
||||
cfg.AuthType = "custom_headers"
|
||||
cfg.CustomHeaders = `{"X-Org-ID":"acme"}`
|
||||
cfg.CustomHeadersUserKeys = []string{"X-User-Token", "X-Workspace"}
|
||||
|
||||
userHeaderValues := []database.McpServerUserHeaderValue{{
|
||||
MCPServerConfigID: cfg.ID,
|
||||
UserID: uuid.New(),
|
||||
HeaderValues: `{"X-User-Token":"jwt-abc","X-Workspace":"main"}`,
|
||||
}}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
userHeaderValues,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-overlay",
|
||||
Name: "overlay__check",
|
||||
Input: "{}",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.NotEmpty(t, seenOrg)
|
||||
assert.Equal(t, "acme", seenOrg[len(seenOrg)-1], "admin header preserved")
|
||||
assert.Equal(t, "jwt-abc", seenUser[len(seenUser)-1], "user-set header overlays")
|
||||
assert.Equal(t, "main", seenWS[len(seenWS)-1], "user-set header overlays")
|
||||
}
|
||||
|
||||
// TestConnectAll_CustomHeadersUserKeysMissingRow verifies that when
|
||||
// CustomHeadersUserKeys is non-empty but no user header values row is
|
||||
// present, admin headers still go out and user keys are simply absent.
|
||||
func TestConnectAll_CustomHeadersUserKeysMissingRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seenOrg []string
|
||||
seenUser []string
|
||||
)
|
||||
srv := mcpserver.NewMCPServer("missing-server", "1.0.0")
|
||||
srv.AddTools(mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool("check",
|
||||
mcp.WithDescription("Returns the auth headers"),
|
||||
),
|
||||
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
mu.Lock()
|
||||
seenOrg = append(seenOrg, req.Header.Get("X-Org-ID"))
|
||||
seenUser = append(seenUser, req.Header.Get("X-User-Token"))
|
||||
mu.Unlock()
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
})
|
||||
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
|
||||
ts := httptest.NewServer(httpSrv)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("missing", ts.URL)
|
||||
cfg.AuthType = "custom_headers"
|
||||
cfg.CustomHeaders = `{"X-Org-ID":"acme"}`
|
||||
cfg.CustomHeadersUserKeys = []string{"X-User-Token"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
nil, // no userHeaderValues at all
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-missing",
|
||||
Name: "missing__check",
|
||||
Input: "{}",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.NotEmpty(t, seenOrg)
|
||||
assert.Equal(t, "acme", seenOrg[len(seenOrg)-1])
|
||||
assert.Equal(t, "", seenUser[len(seenUser)-1])
|
||||
}
|
||||
|
||||
// staticOIDCSource implements mcpclient.UserOIDCTokenSource for tests
|
||||
// without requiring a real OIDC provider or database round-trip.
|
||||
type staticOIDCSource struct {
|
||||
@@ -730,7 +867,7 @@ func TestConnectAll_UserOIDCAuth(t *testing.T) {
|
||||
src := staticOIDCSource{token: "fake-oidc-token"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil,
|
||||
userID, src, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -789,7 +926,7 @@ func TestConnectAll_UserOIDCAuth_NoLink(t *testing.T) {
|
||||
src := staticOIDCSource{token: "", err: nil}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil,
|
||||
uuid.New(), src, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -825,7 +962,7 @@ func TestConnectAll_UserOIDCAuth_NilSource(t *testing.T) {
|
||||
cfg.AuthType = "user_oidc"
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil,
|
||||
uuid.New(), nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
@@ -854,6 +991,7 @@ func TestConnectAll_ParallelConnections(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2, cfg3},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -916,7 +1054,7 @@ func TestConnectAll_ExpiredToken(t *testing.T) {
|
||||
Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// The server accepts any auth, so the tool is still discovered
|
||||
@@ -949,7 +1087,7 @@ func TestConnectAll_EmptyAccessToken(t *testing.T) {
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Tool is still discovered (server doesn't require auth), but
|
||||
@@ -979,7 +1117,7 @@ func TestConnectAll_MCPToolIdentifier(t *testing.T) {
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
@@ -1025,6 +1163,7 @@ func TestConnectAll_MCPToolIdentifier_MultipleServers(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
nil,
|
||||
)
|
||||
@@ -1083,7 +1222,7 @@ func TestConnectAll_EmbeddedResourceText(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("embed-txt", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1150,7 +1289,7 @@ func TestConnectAll_EmbeddedResourceBlob(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("embed-blob", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1230,7 +1369,7 @@ func TestConnectAll_ResourceLink(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("res-link", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1274,7 +1413,7 @@ func TestConnectAll_CallToolError(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("err-srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1298,7 +1437,7 @@ func TestModelIntent_Info_WrapsSchema(t *testing.T) {
|
||||
cfg := makeConfig("intent-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1334,7 +1473,7 @@ func TestModelIntent_Info_NoWrapWhenDisabled(t *testing.T) {
|
||||
cfg := makeConfig("no-intent", ts.URL)
|
||||
cfg.ModelIntent = false
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1357,7 +1496,7 @@ func TestModelIntent_Run_UnwrapsProperties(t *testing.T) {
|
||||
cfg := makeConfig("unwrap-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1382,7 +1521,7 @@ func TestModelIntent_Run_UnwrapsFlat(t *testing.T) {
|
||||
cfg := makeConfig("flat-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1407,7 +1546,7 @@ func TestModelIntent_Run_PassthroughWhenDisabled(t *testing.T) {
|
||||
cfg := makeConfig("pass-srv", ts.URL)
|
||||
cfg.ModelIntent = false
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1432,7 +1571,7 @@ func TestModelIntent_Run_FallbackOnBadJSON(t *testing.T) {
|
||||
cfg := makeConfig("bad-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package mcpclient
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mcpHTTPClient returns an isolated *http.Client when running
|
||||
// inside tests, or nil for production. During tests,
|
||||
// httptest.Server.Close() calls
|
||||
// http.DefaultTransport.CloseIdleConnections(), which disrupts
|
||||
// any MCP client sharing that transport. When DefaultTransport
|
||||
// is a *http.Transport it is cloned; otherwise a minimal
|
||||
// transport with ProxyFromEnvironment is created as a fallback.
|
||||
func mcpHTTPClient() *http.Client {
|
||||
if !testing.Testing() {
|
||||
return nil
|
||||
}
|
||||
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
return &http.Client{Transport: dt.Clone()}
|
||||
}
|
||||
return &http.Client{Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}}
|
||||
}
|
||||
Reference in New Issue
Block a user