Files
coder/coderd/x/chatd/mcpclient/coder_headers_test.go
T
Steven Masley d2f9ad783e feat(coderd/x/chatd): overlay user-set custom_headers at runtime
Threads the per-user custom_headers values stored on
mcp_server_user_header_values through the chatd MCP client so users
who provided a value for an admin-marked CustomHeadersUserKey see it
mixed into the outgoing request alongside the admin-static headers.

Changes:

- mcpclient.ConnectAll grows a fourth indexed input,
  []database.McpServerUserHeaderValue, which buildAuthHeaders
  consults inside the custom_headers branch to overlay per-user
  values on top of admin static headers, scoped to
  cfg.CustomHeadersUserKeys.
- chatd loads the user's stored header values via
  GetMCPServerUserHeaderValuesByUserID alongside the existing
  GetMCPServerUserTokensByUserID call and threads them into
  ConnectAll. A missing row is non-fatal: admin headers still
  ship, user-keyed headers are simply absent, and a warning is
  logged.
- mcpclient.go inlines its own DefaultTransport clone for test
  isolation, replacing the standalone helper in mcphttpclient.go,
  which is removed.

Stack: 4/6 (chatd runtime overlay)
2026-06-01 15:02:34 +00:00

331 lines
10 KiB
Go

package mcpclient_test
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
)
// newHeaderRecordingServer starts a streamable HTTP MCP server (closed via
// t.Cleanup) that exposes a single "ping" tool. Each time the "ping" tool
// is invoked, a clone of that tool-call request's HTTP headers is appended
// to the returned slice. Other MCP traffic handled by the server (for
// example the initialize handshake) is NOT recorded, so assertions over
// the slice only cover tool-call requests.
//
// The returned mutex guards the slice: the handler runs on the HTTP
// server's goroutines, so callers must hold the mutex while reading
// *headers.
func newHeaderRecordingServer(t *testing.T) (*httptest.Server, *sync.Mutex, *[]http.Header) {
	t.Helper()

	var (
		mu      sync.Mutex
		headers []http.Header
	)

	srv := mcpserver.NewMCPServer("hdr-server", "1.0.0")
	srv.AddTools(mcpserver.ServerTool{
		Tool: mcp.NewTool("ping", mcp.WithDescription("records the request headers")),
		Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
			// Clone so the recorded copy does not alias the request's
			// header map.
			snapshot := req.Header.Clone()
			mu.Lock()
			headers = append(headers, snapshot)
			mu.Unlock()
			return mcp.NewToolResultText("ok"), nil
		},
	})

	ts := httptest.NewServer(mcpserver.NewStreamableHTTPServer(srv))
	t.Cleanup(ts.Close)
	return ts, &mu, &headers
}
// TestConnectAll_ForwardCoderHeaders_DefaultOff is a regression guard
// that the Coder identity headers are NOT sent when ForwardCoderHeaders
// is left at its default (false), even though ConnectAll is handed a
// complete set of identity headers it could forward.
func TestConnectAll_ForwardCoderHeaders_DefaultOff(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	ts, mu, recorded := newHeaderRecordingServer(t)

	cfg := makeConfig("no-hdr", ts.URL)
	assert.False(t, cfg.ForwardCoderHeaders, "default must be false")

	// Populate every identity header, including the subchat one, so each
	// "must be absent" assertion below would actually catch a leak rather
	// than passing vacuously.
	coderHeaders := map[string]string{
		chatprovider.HeaderCoderOwnerID:     uuid.NewString(),
		chatprovider.HeaderCoderChatID:      uuid.NewString(),
		chatprovider.HeaderCoderSubchatID:   uuid.NewString(),
		chatprovider.HeaderCoderWorkspaceID: uuid.NewString(),
	}

	tools, cleanup := mcpclient.ConnectAll(
		ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
		coderHeaders,
	)
	t.Cleanup(cleanup)
	require.Len(t, tools, 1)

	_, err := tools[0].Run(ctx, fantasy.ToolCall{
		ID: "call-1", Name: "no-hdr__ping", Input: "{}",
	})
	require.NoError(t, err)

	mu.Lock()
	defer mu.Unlock()
	require.NotEmpty(t, *recorded)
	for _, h := range *recorded {
		assert.Empty(t, h.Get(chatprovider.HeaderCoderOwnerID))
		assert.Empty(t, h.Get(chatprovider.HeaderCoderChatID))
		assert.Empty(t, h.Get(chatprovider.HeaderCoderSubchatID))
		assert.Empty(t, h.Get(chatprovider.HeaderCoderWorkspaceID))
	}
}
// TestConnectAll_ForwardCoderHeaders_Enabled verifies that when
// ForwardCoderHeaders is enabled, the Coder identity headers (owner,
// chat, subchat and workspace) are attached to every recorded tool-call
// request, not just the most recent one.
func TestConnectAll_ForwardCoderHeaders_Enabled(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	ts, mu, recorded := newHeaderRecordingServer(t)

	ownerID := uuid.New()
	chatID := uuid.New()
	workspaceID := uuid.New()
	subchatID := uuid.New()

	cfg := makeConfig("hdr", ts.URL)
	cfg.ForwardCoderHeaders = true

	// Subchat headers: the parent's chat ID lives in X-Coder-Chat-Id and
	// the subchat's own ID lives in X-Coder-Subchat-Id.
	coderHeaders := chatprovider.CoderHeaders(database.Chat{
		ID:           subchatID,
		OwnerID:      ownerID,
		ParentChatID: uuid.NullUUID{UUID: chatID, Valid: true},
		WorkspaceID:  uuid.NullUUID{UUID: workspaceID, Valid: true},
	})

	tools, cleanup := mcpclient.ConnectAll(
		ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
		coderHeaders,
	)
	t.Cleanup(cleanup)
	require.Len(t, tools, 1)

	_, err := tools[0].Run(ctx, fantasy.ToolCall{
		ID: "call-1", Name: "hdr__ping", Input: "{}",
	})
	require.NoError(t, err)

	mu.Lock()
	defer mu.Unlock()
	require.NotEmpty(t, *recorded)
	// Check every recorded request so the test enforces its documented
	// contract ("every outgoing request") rather than only the last one.
	for i, h := range *recorded {
		assert.Equal(t, ownerID.String(), h.Get(chatprovider.HeaderCoderOwnerID), "request %d", i)
		assert.Equal(t, chatID.String(), h.Get(chatprovider.HeaderCoderChatID), "request %d", i)
		assert.Equal(t, subchatID.String(), h.Get(chatprovider.HeaderCoderSubchatID), "request %d", i)
		assert.Equal(t, workspaceID.String(), h.Get(chatprovider.HeaderCoderWorkspaceID), "request %d", i)
	}
}
// TestConnectAll_ForwardCoderHeaders_RootChat covers a top-level chat
// (ParentChatID unset, no workspace): X-Coder-Chat-Id must carry the
// chat's own ID, while the subchat and workspace headers must be absent.
func TestConnectAll_ForwardCoderHeaders_RootChat(t *testing.T) {
	t.Parallel()

	var (
		ctx    = t.Context()
		logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	)
	srv, lock, seen := newHeaderRecordingServer(t)

	owner, chat := uuid.New(), uuid.New()

	cfg := makeConfig("hdr-root", srv.URL)
	cfg.ForwardCoderHeaders = true

	identity := chatprovider.CoderHeaders(database.Chat{ID: chat, OwnerID: owner})

	configs := []database.MCPServerConfig{cfg}
	tools, cleanup := mcpclient.ConnectAll(ctx, logger, configs, nil, nil, uuid.Nil, nil, identity)
	t.Cleanup(cleanup)
	require.Len(t, tools, 1)

	call := fantasy.ToolCall{ID: "call-1", Name: "hdr-root__ping", Input: "{}"}
	_, err := tools[0].Run(ctx, call)
	require.NoError(t, err)

	lock.Lock()
	defer lock.Unlock()
	require.NotEmpty(t, *seen)
	got := (*seen)[len(*seen)-1]

	assert.Equal(t, owner.String(), got.Get(chatprovider.HeaderCoderOwnerID))
	assert.Equal(t, chat.String(), got.Get(chatprovider.HeaderCoderChatID))
	assert.Empty(t, got.Get(chatprovider.HeaderCoderSubchatID))
	assert.Empty(t, got.Get(chatprovider.HeaderCoderWorkspaceID))
}
// TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth checks that enabling
// ForwardCoderHeaders does not displace the api_key auth header. The
// configured X-Api-Key value and the Coder owner/chat headers must all
// arrive on the tool call.
func TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	srv, lock, seen := newHeaderRecordingServer(t)

	owner := uuid.New()
	chat := uuid.New()

	const (
		keyHeader = "X-Api-Key"
		keyValue  = "sekret"
	)
	cfg := makeConfig("hdr-apikey", srv.URL)
	cfg.AuthType = "api_key"
	cfg.APIKeyHeader = keyHeader
	cfg.APIKeyValue = keyValue
	cfg.ForwardCoderHeaders = true

	identity := chatprovider.CoderHeaders(database.Chat{ID: chat, OwnerID: owner})

	configs := []database.MCPServerConfig{cfg}
	tools, cleanup := mcpclient.ConnectAll(ctx, logger, configs, nil, nil, uuid.Nil, nil, identity)
	t.Cleanup(cleanup)
	require.Len(t, tools, 1)

	call := fantasy.ToolCall{ID: "call-1", Name: "hdr-apikey__ping", Input: "{}"}
	if _, err := tools[0].Run(ctx, call); err != nil {
		require.NoError(t, err)
	}

	lock.Lock()
	defer lock.Unlock()
	require.NotEmpty(t, *seen)
	got := (*seen)[len(*seen)-1]

	assert.Equal(t, keyValue, got.Get(keyHeader))
	assert.Equal(t, owner.String(), got.Get(chatprovider.HeaderCoderOwnerID))
	assert.Equal(t, chat.String(), got.Get(chatprovider.HeaderCoderChatID))
}
// TestConnectAll_ForwardCoderHeaders_WithOAuth2 checks that the oauth2
// Authorization header survives alongside forwarded Coder identity
// headers. When the forwarded set also carries an Authorization entry,
// the token-derived value must take precedence.
func TestConnectAll_ForwardCoderHeaders_WithOAuth2(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	srv, lock, seen := newHeaderRecordingServer(t)

	serverID := uuid.New()
	cfg := makeConfig("hdr-oauth", srv.URL)
	cfg.ID = serverID
	cfg.AuthType = "oauth2"
	cfg.ForwardCoderHeaders = true

	tokens := []database.MCPServerUserToken{{
		MCPServerConfigID: serverID,
		AccessToken:       "oauth-token-xyz",
		TokenType:         "Bearer",
	}}

	// The forwarded set deliberately includes an Authorization entry so
	// the test proves the oauth2 header wins the conflict.
	owner := uuid.NewString()
	identity := map[string]string{
		"Authorization":                 "Bearer should-be-overridden",
		chatprovider.HeaderCoderOwnerID: owner,
	}

	configs := []database.MCPServerConfig{cfg}
	tools, cleanup := mcpclient.ConnectAll(ctx, logger, configs, tokens, nil, uuid.Nil, nil, identity)
	t.Cleanup(cleanup)
	require.Len(t, tools, 1)

	call := fantasy.ToolCall{ID: "call-1", Name: "hdr-oauth__ping", Input: "{}"}
	_, err := tools[0].Run(ctx, call)
	require.NoError(t, err)

	lock.Lock()
	defer lock.Unlock()
	require.NotEmpty(t, *seen)
	got := (*seen)[len(*seen)-1]

	assert.Equal(t, "Bearer oauth-token-xyz", got.Get("Authorization"))
	assert.Equal(t, owner, got.Get(chatprovider.HeaderCoderOwnerID))
}
// TestConnectAll_ForwardCoderHeaders_WithCustomHeaders verifies that
// custom_headers admin-configured values are preserved when Coder
// identity headers are forwarded alongside. This includes the case where
// the admin configures a custom header whose name differs from a Coder
// identity header only by case. Conflict detection is case-insensitive
// because http.Header.Set canonicalizes header names.
func TestConnectAll_ForwardCoderHeaders_WithCustomHeaders(t *testing.T) {
	t.Parallel()
	ctx := t.Context()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	ts, mu, recorded := newHeaderRecordingServer(t)

	ownerID := uuid.New()
	chatID := uuid.New()

	cfg := makeConfig("hdr-custom", ts.URL)
	cfg.AuthType = "custom_headers"
	// Include both an unrelated custom header AND a case-variant of
	// X-Coder-Owner-Id to exercise the case-insensitive conflict
	// check. The admin-configured value MUST win.
	cfg.CustomHeaders = `{"X-Tenant":"acme","x-coder-owner-id":"admin-controlled"}`
	cfg.ForwardCoderHeaders = true

	coderHeaders := chatprovider.CoderHeaders(database.Chat{
		ID:      chatID,
		OwnerID: ownerID,
	})

	tools, cleanup := mcpclient.ConnectAll(
		ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
		coderHeaders,
	)
	t.Cleanup(cleanup)
	require.Len(t, tools, 1)

	_, err := tools[0].Run(ctx, fantasy.ToolCall{
		ID: "call-1", Name: "hdr-custom__ping", Input: "{}",
	})
	require.NoError(t, err)

	mu.Lock()
	defer mu.Unlock()
	require.NotEmpty(t, *recorded)
	last := (*recorded)[len(*recorded)-1]

	assert.Equal(t, "acme", last.Get("X-Tenant"))
	// The admin's case-variant header must win, because HTTP header
	// names are case-insensitive at the transport level.
	assert.Equal(t, "admin-controlled", last.Get(chatprovider.HeaderCoderOwnerID))
	// Get only returns the first value. Also require that the Coder value
	// did not leak through as a second entry under the same canonical
	// header name.
	assert.Len(t, last.Values(chatprovider.HeaderCoderOwnerID), 1,
		"conflicting owner header must be sent exactly once")
	assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID))
}