diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index e50edfbfea..c4ae990f93 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -229,6 +229,7 @@ func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerCon Enabled: takeFirst(seed.Enabled, true), ModelIntent: seed.ModelIntent, AllowInPlanMode: seed.AllowInPlanMode, + ForwardCoderHeaders: seed.ForwardCoderHeaders, CreatedBy: createdBy, UpdatedBy: updatedBy, }) diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 1ed77d9c0b..4c9ee415cc 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1800,6 +1800,7 @@ CREATE TABLE mcp_server_configs ( updated_at timestamp with time zone DEFAULT now() NOT NULL, model_intent boolean DEFAULT false NOT NULL, allow_in_plan_mode boolean DEFAULT false NOT NULL, + forward_coder_headers boolean DEFAULT false NOT NULL, CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text, 'user_oidc'::text]))), CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))), CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text]))) diff --git a/coderd/database/migrations/000491_mcp_server_forward_coder_headers.down.sql b/coderd/database/migrations/000491_mcp_server_forward_coder_headers.down.sql new file mode 100644 index 0000000000..e4ef51bfc4 --- /dev/null +++ b/coderd/database/migrations/000491_mcp_server_forward_coder_headers.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE mcp_server_configs + DROP COLUMN forward_coder_headers; diff --git a/coderd/database/migrations/000491_mcp_server_forward_coder_headers.up.sql b/coderd/database/migrations/000491_mcp_server_forward_coder_headers.up.sql new file mode 100644 index 0000000000..dfa63fc936 --- /dev/null +++ b/coderd/database/migrations/000491_mcp_server_forward_coder_headers.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE mcp_server_configs + ADD COLUMN forward_coder_headers BOOLEAN NOT NULL DEFAULT false; diff --git a/coderd/database/migrations/testdata/fixtures/000491_mcp_server_forward_coder_headers.up.sql b/coderd/database/migrations/testdata/fixtures/000491_mcp_server_forward_coder_headers.up.sql new file mode 100644 index 0000000000..33aba5897b --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000491_mcp_server_forward_coder_headers.up.sql @@ -0,0 +1,6 @@ +-- Migration 491 adds forward_coder_headers with a default of false. +-- Flip the existing fixture row to true here so fixture data exercises +-- the non-default state only after the column exists. +UPDATE mcp_server_configs +SET forward_coder_headers = TRUE +WHERE id = 'a1b2c3d4-e5f6-7890-abcd-ef1234567890'; diff --git a/coderd/database/models.go b/coderd/database/models.go index 143a97a15a..6c90c0499a 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4765,6 +4765,7 @@ type MCPServerConfig struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` ModelIntent bool `db:"model_intent" json:"model_intent"` AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` + ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` } type MCPServerUserToken struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 913231b402..fc6df826cd 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -13209,7 +13209,7 @@ func (q *sqlQuerier) DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCP const getEnabledMCPServerConfigs = `-- name: GetEnabledMCPServerConfigs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers FROM mcp_server_configs WHERE @@ -13257,6 +13257,7 @@ func (q *sqlQuerier) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServe &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ); err != nil { return nil, err } @@ -13273,7 +13274,7 @@ func (q *sqlQuerier) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServe const getForcedMCPServerConfigs = `-- name: GetForcedMCPServerConfigs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers FROM mcp_server_configs WHERE @@ -13322,6 +13323,7 @@ func (q *sqlQuerier) GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServer &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ); err != nil { return nil, err } @@ -13338,7 +13340,7 @@ func (q *sqlQuerier) GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServer const getMCPServerConfigByID = `-- name: GetMCPServerConfigByID :one SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers FROM mcp_server_configs WHERE @@ -13378,13 +13380,14 @@ func (q *sqlQuerier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) ( &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ) return i, err } const getMCPServerConfigBySlug = `-- name: GetMCPServerConfigBySlug :one SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers FROM mcp_server_configs WHERE @@ -13424,13 +13427,14 @@ func (q *sqlQuerier) GetMCPServerConfigBySlug(ctx context.Context, slug string) &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ) return i, err } const getMCPServerConfigs = `-- name: GetMCPServerConfigs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers FROM mcp_server_configs ORDER BY @@ -13476,6 +13480,7 @@ func (q *sqlQuerier) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ); err != nil { return nil, err } @@ -13492,7 +13497,7 @@ func (q *sqlQuerier) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig const getMCPServerConfigsByIDs = `-- name: GetMCPServerConfigsByIDs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers FROM mcp_server_configs WHERE @@ -13540,6 +13545,7 @@ func (q *sqlQuerier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UU &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ); err != nil { return nil, err } @@ -13658,6 +13664,7 @@ INSERT INTO mcp_server_configs ( enabled, model_intent, allow_in_plan_mode, + forward_coder_headers, created_by, updated_by ) VALUES ( @@ -13685,11 +13692,12 @@ INSERT INTO mcp_server_configs ( $22::boolean, $23::boolean, $24::boolean, - $25::uuid, - $26::uuid + $25::boolean, + $26::uuid, + $27::uuid ) RETURNING - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers ` type InsertMCPServerConfigParams struct { @@ -13717,6 +13725,7 @@ type InsertMCPServerConfigParams struct { Enabled bool `db:"enabled" json:"enabled"` ModelIntent bool `db:"model_intent" json:"model_intent"` AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` + ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` CreatedBy uuid.UUID `db:"created_by" json:"created_by"` UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` } @@ -13747,6 +13756,7 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer arg.Enabled, arg.ModelIntent, arg.AllowInPlanMode, + arg.ForwardCoderHeaders, arg.CreatedBy, arg.UpdatedBy, ) @@ -13781,6 +13791,7 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ) return i, err } @@ -13813,12 +13824,13 @@ SET enabled = $22::boolean, model_intent = $23::boolean, allow_in_plan_mode = $24::boolean, - updated_by = $25::uuid, + forward_coder_headers = $25::boolean, + updated_by = $26::uuid, updated_at = NOW() WHERE - id = $26::uuid + id = $27::uuid RETURNING - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers ` type UpdateMCPServerConfigParams struct { @@ -13846,6 +13858,7 @@ type UpdateMCPServerConfigParams struct { Enabled bool `db:"enabled" json:"enabled"` ModelIntent bool `db:"model_intent" json:"model_intent"` AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` + ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` ID uuid.UUID `db:"id" json:"id"` } @@ -13876,6 +13889,7 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer arg.Enabled, arg.ModelIntent, arg.AllowInPlanMode, + arg.ForwardCoderHeaders, arg.UpdatedBy, arg.ID, ) @@ -13910,6 +13924,7 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer &i.UpdatedAt, &i.ModelIntent, &i.AllowInPlanMode, + &i.ForwardCoderHeaders, ) return i, err } diff --git a/coderd/database/queries/mcpserverconfigs.sql b/coderd/database/queries/mcpserverconfigs.sql index 103bbaea17..3d05a2b102 100644 --- a/coderd/database/queries/mcpserverconfigs.sql +++ b/coderd/database/queries/mcpserverconfigs.sql @@ -79,6 +79,7 @@ INSERT INTO mcp_server_configs ( enabled, model_intent, allow_in_plan_mode, + forward_coder_headers, created_by, updated_by ) VALUES ( @@ -106,6 +107,7 @@ INSERT INTO mcp_server_configs ( @enabled::boolean, @model_intent::boolean, @allow_in_plan_mode::boolean, + @forward_coder_headers::boolean, @created_by::uuid, @updated_by::uuid ) @@ -140,6 +142,7 @@ SET enabled = @enabled::boolean, model_intent = @model_intent::boolean, allow_in_plan_mode = @allow_in_plan_mode::boolean, + forward_coder_headers = @forward_coder_headers::boolean, updated_by = @updated_by::uuid, updated_at = NOW() WHERE diff --git a/coderd/mcp.go b/coderd/mcp.go index b3b7d5619f..3e0a5829f7 100644 --- a/coderd/mcp.go +++ b/coderd/mcp.go @@ -283,6 +283,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { Enabled: req.Enabled, ModelIntent: req.ModelIntent, AllowInPlanMode: req.AllowInPlanMode, + ForwardCoderHeaders: req.ForwardCoderHeaders, CreatedBy: apiKey.UserID, UpdatedBy: apiKey.UserID, }) @@ -371,6 +372,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { Enabled: inserted.Enabled, ModelIntent: inserted.ModelIntent, AllowInPlanMode: inserted.AllowInPlanMode, + ForwardCoderHeaders: inserted.ForwardCoderHeaders, UpdatedBy: apiKey.UserID, }) if err != nil { @@ -440,6 +442,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { Enabled: req.Enabled, ModelIntent: req.ModelIntent, AllowInPlanMode: req.AllowInPlanMode, + ForwardCoderHeaders: req.ForwardCoderHeaders, CreatedBy: apiKey.UserID, UpdatedBy: apiKey.UserID, }) @@ -699,6 +702,11 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) { allowInPlanMode = *req.AllowInPlanMode } + forwardCoderHeaders := existing.ForwardCoderHeaders + if req.ForwardCoderHeaders != nil { + forwardCoderHeaders = *req.ForwardCoderHeaders + } + // When auth_type changes, clear fields belonging to the // previous auth type so stale secrets don't persist. if authType != existing.AuthType { @@ -783,6 +791,7 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) { Enabled: enabled, ModelIntent: modelIntent, AllowInPlanMode: allowInPlanMode, + ForwardCoderHeaders: forwardCoderHeaders, UpdatedBy: apiKey.UserID, ID: existing.ID, }) @@ -1264,11 +1273,12 @@ func convertMCPServerConfig(config database.MCPServerConfig) codersdk.MCPServerC Availability: config.Availability, - Enabled: config.Enabled, - ModelIntent: config.ModelIntent, - AllowInPlanMode: config.AllowInPlanMode, - CreatedAt: config.CreatedAt, - UpdatedAt: config.UpdatedAt, + Enabled: config.Enabled, + ModelIntent: config.ModelIntent, + AllowInPlanMode: config.AllowInPlanMode, + ForwardCoderHeaders: config.ForwardCoderHeaders, + CreatedAt: config.CreatedAt, + UpdatedAt: config.UpdatedAt, } } diff --git a/coderd/mcp_test.go b/coderd/mcp_test.go index 7b0ce137f8..add730960f 100644 --- a/coderd/mcp_test.go +++ b/coderd/mcp_test.go @@ -99,6 +99,7 @@ func TestMCPServerConfigsCRUD(t *testing.T) { require.Equal(t, "default_on", created.Availability) require.True(t, created.Enabled) require.False(t, created.AllowInPlanMode) + require.False(t, created.ForwardCoderHeaders) // Verify the secret is indicated but never returned. require.True(t, created.HasOAuth2Secret) @@ -110,25 +111,31 @@ func TestMCPServerConfigsCRUD(t *testing.T) { require.Equal(t, created.ID, configs[0].ID) require.True(t, configs[0].HasOAuth2Secret) require.False(t, configs[0].AllowInPlanMode) + require.False(t, configs[0].ForwardCoderHeaders) fetched, err := client.MCPServerConfigByID(ctx, created.ID) require.NoError(t, err) require.Equal(t, created.ID, fetched.ID) require.False(t, fetched.AllowInPlanMode) + require.False(t, fetched.ForwardCoderHeaders) - // Update display name, availability, and allow_in_plan_mode. + // Update display name, availability, allow_in_plan_mode, and + // forward_coder_headers. newName := "Renamed Server" newAvail := "force_on" allowInPlanMode := true + forwardCoderHeaders := true updated, err := client.UpdateMCPServerConfig(ctx, created.ID, codersdk.UpdateMCPServerConfigRequest{ - DisplayName: &newName, - Availability: &newAvail, - AllowInPlanMode: &allowInPlanMode, + DisplayName: &newName, + Availability: &newAvail, + AllowInPlanMode: &allowInPlanMode, + ForwardCoderHeaders: &forwardCoderHeaders, }) require.NoError(t, err) require.Equal(t, "Renamed Server", updated.DisplayName) require.Equal(t, "force_on", updated.Availability) require.True(t, updated.AllowInPlanMode) + require.True(t, updated.ForwardCoderHeaders) // Unchanged fields should remain the same. require.Equal(t, "my-mcp-server", updated.Slug) require.Equal(t, "oauth2", updated.AuthType) @@ -140,10 +147,12 @@ func TestMCPServerConfigsCRUD(t *testing.T) { require.Equal(t, "Renamed Server", configs[0].DisplayName) require.Equal(t, "force_on", configs[0].Availability) require.True(t, configs[0].AllowInPlanMode) + require.True(t, configs[0].ForwardCoderHeaders) fetched, err = client.MCPServerConfigByID(ctx, created.ID) require.NoError(t, err) require.True(t, fetched.AllowInPlanMode) + require.True(t, fetched.ForwardCoderHeaders) // Delete it. err = client.DeleteMCPServerConfig(ctx, created.ID) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index bcb493eb64..8c8e20abe2 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -6977,6 +6977,7 @@ func (p *Server) runChat( mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens) mcpTools, mcpCleanup = mcpclient.ConnectAll( ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource, + chatprovider.CoderHeaders(chat), ) return nil }) diff --git a/coderd/x/chatd/mcpclient/coder_headers_test.go b/coderd/x/chatd/mcpclient/coder_headers_test.go new file mode 100644 index 0000000000..f90a031d5a --- /dev/null +++ b/coderd/x/chatd/mcpclient/coder_headers_test.go @@ -0,0 +1,329 @@ +package mcpclient_test + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/mcpclient" +) + +// newHeaderRecordingServer creates a streamable HTTP MCP server with a +// single "ping" tool. Every request's headers are appended to the +// returned slice so tests can assert which headers were forwarded. +func newHeaderRecordingServer(t *testing.T) (*httptest.Server, *sync.Mutex, *[]http.Header) { + t.Helper() + var ( + mu sync.Mutex + headers []http.Header + ) + srv := mcpserver.NewMCPServer("hdr-server", "1.0.0") + srv.AddTools(mcpserver.ServerTool{ + Tool: mcp.NewTool("ping", mcp.WithDescription("records the request headers")), + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mu.Lock() + headers = append(headers, req.Header.Clone()) + mu.Unlock() + return mcp.NewToolResultText("ok"), nil + }, + }) + httpSrv := mcpserver.NewStreamableHTTPServer(srv) + ts := httptest.NewServer(httpSrv) + t.Cleanup(ts.Close) + return ts, &mu, &headers +} + +// TestConnectAll_ForwardCoderHeaders_DefaultOff is a regression guard +// that the Coder identity headers are NOT sent when the option is +// left at its default (false). +func TestConnectAll_ForwardCoderHeaders_DefaultOff(t *testing.T) { + t.Parallel() + ctx := t.Context() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts, mu, recorded := newHeaderRecordingServer(t) + + cfg := makeConfig("no-hdr", ts.URL) + assert.False(t, cfg.ForwardCoderHeaders, "default must be false") + + coderHeaders := map[string]string{ + chatprovider.HeaderCoderOwnerID: uuid.NewString(), + chatprovider.HeaderCoderChatID: uuid.NewString(), + chatprovider.HeaderCoderWorkspaceID: uuid.NewString(), + } + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + coderHeaders, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + _, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-1", Name: "no-hdr__ping", Input: "{}", + }) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, *recorded) + for _, h := range *recorded { + assert.Empty(t, h.Get(chatprovider.HeaderCoderOwnerID)) + assert.Empty(t, h.Get(chatprovider.HeaderCoderChatID)) + assert.Empty(t, h.Get(chatprovider.HeaderCoderSubchatID)) + assert.Empty(t, h.Get(chatprovider.HeaderCoderWorkspaceID)) + } +} + +// TestConnectAll_ForwardCoderHeaders_Enabled verifies that when the +// option is enabled, the Coder identity headers are forwarded on every +// outgoing MCP request, including the subchat and workspace headers. +func TestConnectAll_ForwardCoderHeaders_Enabled(t *testing.T) { + t.Parallel() + ctx := t.Context() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts, mu, recorded := newHeaderRecordingServer(t) + + ownerID := uuid.New() + chatID := uuid.New() + workspaceID := uuid.New() + subchatID := uuid.New() + + cfg := makeConfig("hdr", ts.URL) + cfg.ForwardCoderHeaders = true + + // Subchat headers: parent's chat ID lives in X-Coder-Chat-Id, the + // subchat's own ID lives in X-Coder-Subchat-Id. + coderHeaders := chatprovider.CoderHeaders(database.Chat{ + ID: subchatID, + OwnerID: ownerID, + ParentChatID: uuid.NullUUID{UUID: chatID, Valid: true}, + WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}, + }) + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + coderHeaders, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + _, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-1", Name: "hdr__ping", Input: "{}", + }) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, *recorded) + last := (*recorded)[len(*recorded)-1] + assert.Equal(t, ownerID.String(), last.Get(chatprovider.HeaderCoderOwnerID)) + assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID)) + assert.Equal(t, subchatID.String(), last.Get(chatprovider.HeaderCoderSubchatID)) + assert.Equal(t, workspaceID.String(), last.Get(chatprovider.HeaderCoderWorkspaceID)) +} + +// TestConnectAll_ForwardCoderHeaders_RootChat verifies that for a root +// chat (no parent), the chat's own ID is forwarded as +// X-Coder-Chat-Id and the X-Coder-Subchat-Id header is absent. +func TestConnectAll_ForwardCoderHeaders_RootChat(t *testing.T) { + t.Parallel() + ctx := t.Context() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts, mu, recorded := newHeaderRecordingServer(t) + + ownerID := uuid.New() + chatID := uuid.New() + + cfg := makeConfig("hdr-root", ts.URL) + cfg.ForwardCoderHeaders = true + + coderHeaders := chatprovider.CoderHeaders(database.Chat{ + ID: chatID, + OwnerID: ownerID, + }) + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + coderHeaders, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + _, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-1", Name: "hdr-root__ping", Input: "{}", + }) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, *recorded) + last := (*recorded)[len(*recorded)-1] + assert.Equal(t, ownerID.String(), last.Get(chatprovider.HeaderCoderOwnerID)) + assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID)) + assert.Empty(t, last.Get(chatprovider.HeaderCoderSubchatID)) + assert.Empty(t, last.Get(chatprovider.HeaderCoderWorkspaceID)) +} + +// TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth verifies that the +// api_key auth header is preserved when Coder identity headers are +// forwarded alongside. +func TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth(t *testing.T) { + t.Parallel() + ctx := t.Context() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts, mu, recorded := newHeaderRecordingServer(t) + + ownerID := uuid.New() + chatID := uuid.New() + + cfg := makeConfig("hdr-apikey", ts.URL) + cfg.AuthType = "api_key" + cfg.APIKeyHeader = "X-Api-Key" + cfg.APIKeyValue = "sekret" + cfg.ForwardCoderHeaders = true + + coderHeaders := chatprovider.CoderHeaders(database.Chat{ + ID: chatID, + OwnerID: ownerID, + }) + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + coderHeaders, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + _, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-1", Name: "hdr-apikey__ping", Input: "{}", + }) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, *recorded) + last := (*recorded)[len(*recorded)-1] + assert.Equal(t, "sekret", last.Get("X-Api-Key")) + assert.Equal(t, ownerID.String(), last.Get(chatprovider.HeaderCoderOwnerID)) + assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID)) +} + +// TestConnectAll_ForwardCoderHeaders_WithOAuth2 verifies that the +// oauth2 Authorization header is preserved when Coder identity +// headers are forwarded alongside, and that auth wins on a conflict. +func TestConnectAll_ForwardCoderHeaders_WithOAuth2(t *testing.T) { + t.Parallel() + ctx := t.Context() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts, mu, recorded := newHeaderRecordingServer(t) + + cfgID := uuid.New() + cfg := makeConfig("hdr-oauth", ts.URL) + cfg.ID = cfgID + cfg.AuthType = "oauth2" + cfg.ForwardCoderHeaders = true + token := database.MCPServerUserToken{ + MCPServerConfigID: cfgID, + AccessToken: "oauth-token-xyz", + TokenType: "Bearer", + } + + // Intentionally include an Authorization key to verify the auth + // header wins on conflict. + ownerID := uuid.NewString() + coderHeaders := map[string]string{ + "Authorization": "Bearer should-be-overridden", + chatprovider.HeaderCoderOwnerID: ownerID, + } + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, + []database.MCPServerConfig{cfg}, + []database.MCPServerUserToken{token}, + uuid.Nil, nil, + coderHeaders, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + _, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-1", Name: "hdr-oauth__ping", Input: "{}", + }) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, *recorded) + last := (*recorded)[len(*recorded)-1] + assert.Equal(t, "Bearer oauth-token-xyz", last.Get("Authorization")) + assert.Equal(t, ownerID, last.Get(chatprovider.HeaderCoderOwnerID)) +} + +// TestConnectAll_ForwardCoderHeaders_WithCustomHeaders verifies that +// custom_headers admin-configured values are preserved when Coder +// identity headers are forwarded alongside, including the case where +// the admin configures a custom header whose name only differs from a +// Coder identity header by case. Conflict detection is case- +// insensitive because http.Header.Set canonicalizes header names. +func TestConnectAll_ForwardCoderHeaders_WithCustomHeaders(t *testing.T) { + t.Parallel() + ctx := t.Context() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts, mu, recorded := newHeaderRecordingServer(t) + + ownerID := uuid.New() + chatID := uuid.New() + + cfg := makeConfig("hdr-custom", ts.URL) + cfg.AuthType = "custom_headers" + // Include both an unrelated custom header AND a case-variant of + // X-Coder-Owner-Id to exercise the case-insensitive conflict + // check. The admin-configured value MUST win. + cfg.CustomHeaders = `{"X-Tenant":"acme","x-coder-owner-id":"admin-controlled"}` + cfg.ForwardCoderHeaders = true + + coderHeaders := chatprovider.CoderHeaders(database.Chat{ + ID: chatID, + OwnerID: ownerID, + }) + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + coderHeaders, + ) + t.Cleanup(cleanup) + require.Len(t, tools, 1) + + _, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-1", Name: "hdr-custom__ping", Input: "{}", + }) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, *recorded) + last := (*recorded)[len(*recorded)-1] + assert.Equal(t, "acme", last.Get("X-Tenant")) + // The admin's case-variant header must win, because HTTP header + // names are case-insensitive at the transport level. + assert.Equal(t, "admin-controlled", last.Get(chatprovider.HeaderCoderOwnerID)) + assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID)) +} diff --git a/coderd/x/chatd/mcpclient/mcpclient.go b/coderd/x/chatd/mcpclient/mcpclient.go index 8b57e9b3a0..16ef5ed9fd 100644 --- a/coderd/x/chatd/mcpclient/mcpclient.go +++ b/coderd/x/chatd/mcpclient/mcpclient.go @@ -74,6 +74,7 @@ func ConnectAll( tokens []database.MCPServerUserToken, userID uuid.UUID, oidcSrc UserOIDCTokenSource, + coderHeaders map[string]string, ) ([]fantasy.AgentTool, func()) { // Index tokens by server config ID so auth header // construction is O(1) per server. @@ -109,7 +110,7 @@ func ConnectAll( eg.Go(func() error { serverTools, mcpClient, connectErr := connectOne( - ctx, logger, cfg, tokensByConfigID, userID, oidcSrc, + ctx, logger, cfg, tokensByConfigID, userID, oidcSrc, coderHeaders, ) if connectErr != nil { logger.Warn(ctx, @@ -175,9 +176,31 @@ func connectOne( tokensByConfigID map[uuid.UUID]database.MCPServerUserToken, userID uuid.UUID, oidcSrc UserOIDCTokenSource, + coderHeaders map[string]string, ) ([]fantasy.AgentTool, *client.Client, error) { headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userID, oidcSrc) + // When opted-in, merge Coder identity headers BEFORE the + // transport is created so any auth header already set above + // wins on a conflict. Conflict detection uses + // http.CanonicalHeaderKey because the upstream transport applies + // http.Header.Set, which canonicalizes keys; without that, an + // admin-configured header that differs only in case from a Coder + // identity header would land in the request map twice and the + // surviving value would be non-deterministic. + if cfg.ForwardCoderHeaders { + canonicalAuth := make(map[string]struct{}, len(headers)) + for k := range headers { + canonicalAuth[http.CanonicalHeaderKey(k)] = struct{}{} + } + for k, v := range coderHeaders { + if _, exists := canonicalAuth[http.CanonicalHeaderKey(k)]; exists { + continue + } + headers[k] = v + } + } + tr, err := createTransport(cfg, headers) if err != nil { return nil, nil, xerrors.Errorf( diff --git a/coderd/x/chatd/mcpclient/mcpclient_test.go b/coderd/x/chatd/mcpclient/mcpclient_test.go index dca1c5a1b8..d91788fd2f 100644 --- a/coderd/x/chatd/mcpclient/mcpclient_test.go +++ b/coderd/x/chatd/mcpclient/mcpclient_test.go @@ -96,7 +96,7 @@ func TestConnectAll_DiscoverTools(t *testing.T) { ts := newTestMCPServer(t, echoTool(), greetTool()) cfg := makeConfig("myserver", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) // Two tools should be discovered, namespaced with the server slug. @@ -121,7 +121,7 @@ func TestConnectAll_CallTool(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -147,7 +147,7 @@ func TestConnectAll_ToolAllowList(t *testing.T) { // Only allow the "echo" tool. cfg.ToolAllowList = []string{"echo"} - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -165,7 +165,7 @@ func TestConnectAll_ToolDenyList(t *testing.T) { // Deny the "greet" tool, so only "echo" remains. cfg.ToolDenyList = []string{"greet"} - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -179,7 +179,7 @@ func TestConnectAll_ConnectionFailure(t *testing.T) { cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist") - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) assert.Empty(t, tools, "no tools should be returned for an unreachable server") @@ -201,6 +201,7 @@ func TestConnectAll_MultipleServers(t *testing.T) { []database.MCPServerConfig{cfg1, cfg2}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -227,6 +228,7 @@ func TestConnectAll_NoToolsAfterFiltering(t *testing.T) { []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + nil, ) require.Empty(t, tools) @@ -255,6 +257,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { }, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -284,6 +287,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { }, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -317,6 +321,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { []database.MCPServerConfig{cfg1, cfg2}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -381,6 +386,7 @@ func TestConnectAll_AuthHeaders(t *testing.T) { []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -435,7 +441,7 @@ func TestConnectAll_DisabledServer(t *testing.T) { cfg := makeConfig("disabled", ts.URL) cfg.Enabled = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) assert.Empty(t, tools) } @@ -450,7 +456,7 @@ func TestConnectAll_CallToolInvalidInput(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -475,7 +481,7 @@ func TestConnectAll_ToolInfoParameters(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -517,7 +523,7 @@ func TestConnectAll_NilRequiredBecomesEmptySlice(t *testing.T) { ts := newTestMCPServer(t, noRequiredTool) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -570,6 +576,7 @@ func TestConnectAll_APIKeyAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -627,6 +634,7 @@ func TestConnectAll_CustomHeadersAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -664,6 +672,7 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -722,7 +731,7 @@ func TestConnectAll_UserOIDCAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, - userID, src, + userID, src, nil, ) t.Cleanup(cleanup) @@ -781,7 +790,7 @@ func TestConnectAll_UserOIDCAuth_NoLink(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, - uuid.New(), src, + uuid.New(), src, nil, ) t.Cleanup(cleanup) @@ -817,7 +826,7 @@ func TestConnectAll_UserOIDCAuth_NilSource(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, - uuid.New(), nil, + uuid.New(), nil, nil, ) t.Cleanup(cleanup) @@ -846,6 +855,7 @@ func TestConnectAll_ParallelConnections(t *testing.T) { []database.MCPServerConfig{cfg1, cfg2, cfg3}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -906,7 +916,7 @@ func TestConnectAll_ExpiredToken(t *testing.T) { Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil) t.Cleanup(cleanup) // The server accepts any auth, so the tool is still discovered @@ -939,7 +949,7 @@ func TestConnectAll_EmptyAccessToken(t *testing.T) { TokenType: "Bearer", } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil, nil) t.Cleanup(cleanup) // Tool is still discovered (server doesn't require auth), but @@ -969,7 +979,7 @@ func TestConnectAll_MCPToolIdentifier(t *testing.T) { Enabled: true, } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1016,6 +1026,7 @@ func TestConnectAll_MCPToolIdentifier_MultipleServers(t *testing.T) { []database.MCPServerConfig{cfg1, cfg2}, nil, uuid.Nil, nil, + nil, ) t.Cleanup(cleanup) @@ -1072,7 +1083,7 @@ func TestConnectAll_EmbeddedResourceText(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("embed-txt", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1139,7 +1150,7 @@ func TestConnectAll_EmbeddedResourceBlob(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("embed-blob", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1219,7 +1230,7 @@ func TestConnectAll_ResourceLink(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("res-link", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1263,7 +1274,7 @@ func TestConnectAll_CallToolError(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("err-srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1287,7 +1298,7 @@ func TestModelIntent_Info_WrapsSchema(t *testing.T) { cfg := makeConfig("intent-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1323,7 +1334,7 @@ func TestModelIntent_Info_NoWrapWhenDisabled(t *testing.T) { cfg := makeConfig("no-intent", ts.URL) cfg.ModelIntent = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1346,7 +1357,7 @@ func TestModelIntent_Run_UnwrapsProperties(t *testing.T) { cfg := makeConfig("unwrap-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1371,7 +1382,7 @@ func TestModelIntent_Run_UnwrapsFlat(t *testing.T) { cfg := makeConfig("flat-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1396,7 +1407,7 @@ func TestModelIntent_Run_PassthroughWhenDisabled(t *testing.T) { cfg := makeConfig("pass-srv", ts.URL) cfg.ModelIntent = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1421,7 +1432,7 @@ func TestModelIntent_Run_FallbackOnBadJSON(t *testing.T) { cfg := makeConfig("bad-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) diff --git a/codersdk/mcp.go b/codersdk/mcp.go index 132c804479..f3d1bd1175 100644 --- a/codersdk/mcp.go +++ b/codersdk/mcp.go @@ -64,11 +64,18 @@ type MCPServerConfig struct { // Availability policy set by admin. Availability string `json:"availability"` // "force_on", "default_on", "default_off" - Enabled bool `json:"enabled"` - ModelIntent bool `json:"model_intent"` - AllowInPlanMode bool `json:"allow_in_plan_mode"` - CreatedAt time.Time `json:"created_at" format:"date-time"` - UpdatedAt time.Time `json:"updated_at" format:"date-time"` + Enabled bool `json:"enabled"` + ModelIntent bool `json:"model_intent"` + AllowInPlanMode bool `json:"allow_in_plan_mode"` + + // ForwardCoderHeaders forwards the same Coder identity headers we + // send to LLM providers (X-Coder-Owner-Id, X-Coder-Chat-Id, and the + // optional X-Coder-Subchat-Id and X-Coder-Workspace-Id) to this + // MCP server on every request. Off by default to avoid leaking + // chat identity to third-party servers. + ForwardCoderHeaders bool `json:"forward_coder_headers"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` // Per-user state (populated for non-admin requests). AuthConnected bool `json:"auth_connected"` @@ -101,6 +108,10 @@ type CreateMCPServerConfigRequest struct { Enabled bool `json:"enabled"` ModelIntent bool `json:"model_intent"` AllowInPlanMode bool `json:"allow_in_plan_mode"` + + // ForwardCoderHeaders, when true, forwards Coder identity + // headers on every outgoing MCP request. See MCPServerConfig. + ForwardCoderHeaders bool `json:"forward_coder_headers"` } // UpdateMCPServerConfigRequest is the request to update an MCP server config. @@ -130,6 +141,10 @@ type UpdateMCPServerConfigRequest struct { Enabled *bool `json:"enabled,omitempty"` ModelIntent *bool `json:"model_intent,omitempty"` AllowInPlanMode *bool `json:"allow_in_plan_mode,omitempty"` + + // ForwardCoderHeaders, when set, updates whether Coder identity + // headers are forwarded on every outgoing MCP request. + ForwardCoderHeaders *bool `json:"forward_coder_headers,omitempty"` } func (c *Client) MCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) { diff --git a/docs/ai-coder/agents/platform-controls/mcp-servers.md b/docs/ai-coder/agents/platform-controls/mcp-servers.md index 86e751625d..15b3b5f219 100644 --- a/docs/ai-coder/agents/platform-controls/mcp-servers.md +++ b/docs/ai-coder/agents/platform-controls/mcp-servers.md @@ -33,11 +33,12 @@ This is an admin-only feature accessible at **Agents** > **Settings** > ### Availability -| Field | Required | Description | -|----------------|----------|-------------------------------------------------------------------------------------------------------------------------------| -| `enabled` | No | Master toggle. Disabled servers are hidden from non-admin users. | -| `availability` | Yes | Controls how the server appears in chat sessions. See [Availability policies](#availability-policies). | -| `model_intent` | No | When enabled, requires the model to describe each tool call's purpose in natural language, shown as a status label in the UI. | +| Field | Required | Description | +|-------------------------|----------|-------------------------------------------------------------------------------------------------------------------------------------| +| `enabled` | No | Master toggle. Disabled servers are hidden from non-admin users. | +| `availability` | Yes | Controls how the server appears in chat sessions. See [Availability policies](#availability-policies). | +| `model_intent` | No | When enabled, requires the model to describe each tool call's purpose in natural language, shown as a status label in the UI. | +| `forward_coder_headers` | No | When enabled, forwards Coder identity headers on every outgoing MCP request. See [Coder identity headers](#coder-identity-headers). | #### Availability policies @@ -129,6 +130,30 @@ Control which tools from a server are available in chat: | `tool_allow_list` | If non-empty, only the listed tool names are exposed. An empty list allows all tools. | | `tool_deny_list` | Listed tool names are always blocked, even if they appear in the allow list. | +## Coder identity headers + +MCP servers configured with `forward_coder_headers = true` receive the +following identity headers on every outgoing request, alongside the +auth header for the configured `auth_type`: + +| Header | Description | +|------------------------|--------------------------------------------------------------------------------------------------------------| +| `X-Coder-Owner-Id` | Coder user who owns the chat that issued the tool call. | +| `X-Coder-Chat-Id` | Top-level (parent) chat ID. For root chats this is the chat's own ID; for subchats it is the parent chat ID. | +| `X-Coder-Subchat-Id` | Subchat ID. Only present when the request originates from a child chat. | +| `X-Coder-Workspace-Id` | Workspace associated with the chat, if any. | + +These are the same headers Coder sends to LLM providers (see +[Coder agents headers](../../ai-gateway/clients/coder-agents.md)) so a +first-party MCP server can correlate a tool call back to the +originating chat. + +Because the headers leak chat identity, the option is **off by +default** and should only be enabled for first-party or trusted +internal MCP servers. If the auth header for the configured +`auth_type` collides with one of these headers, the auth header +wins. + ## Permissions | Action | Required role | diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 51d265e8d2..2f913e20a8 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -3071,6 +3071,11 @@ export interface CreateMCPServerConfigRequest { readonly enabled: boolean; readonly model_intent: boolean; readonly allow_in_plan_mode: boolean; + /** + * ForwardCoderHeaders, when true, forwards Coder identity + * headers on every outgoing MCP request. See MCPServerConfig. + */ + readonly forward_coder_headers: boolean; } // From codersdk/organizations.go @@ -4775,6 +4780,14 @@ export interface MCPServerConfig { readonly enabled: boolean; readonly model_intent: boolean; readonly allow_in_plan_mode: boolean; + /** + * ForwardCoderHeaders forwards the same Coder identity headers we + * send to LLM providers (X-Coder-Owner-Id, X-Coder-Chat-Id, and the + * optional X-Coder-Subchat-Id and X-Coder-Workspace-Id) to this + * MCP server on every request. Off by default to avoid leaking + * chat identity to third-party servers. + */ + readonly forward_coder_headers: boolean; readonly created_at: string; readonly updated_at: string; /** @@ -8230,6 +8243,11 @@ export interface UpdateMCPServerConfigRequest { readonly enabled?: boolean; readonly model_intent?: boolean; readonly allow_in_plan_mode?: boolean; + /** + * ForwardCoderHeaders, when set, updates whether Coder identity + * headers are forwarded on every outgoing MCP request. + */ + readonly forward_coder_headers?: boolean; } // From codersdk/notifications.go diff --git a/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx b/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx index 9145231dd1..02d52ee3f3 100644 --- a/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx +++ b/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx @@ -707,6 +707,7 @@ const makeMCPServer = ( enabled: overrides.enabled ?? true, model_intent: overrides.model_intent ?? false, allow_in_plan_mode: overrides.allow_in_plan_mode ?? false, + forward_coder_headers: overrides.forward_coder_headers ?? false, created_at: overrides.created_at ?? now, updated_at: overrides.updated_at ?? now, auth_connected: overrides.auth_connected ?? false, diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx index 629f0f5800..1f21043869 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.stories.tsx @@ -757,6 +757,7 @@ const sampleMCPServers = [ enabled: true, model_intent: false, allow_in_plan_mode: false, + forward_coder_headers: false, auth_connected: true, created_at: "2025-01-01T00:00:00Z", updated_at: "2025-01-01T00:00:00Z", diff --git a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx index 58a605bb2b..5cbf0923b2 100644 --- a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx @@ -34,6 +34,7 @@ const createServerConfig = ( enabled: overrides.enabled ?? true, model_intent: overrides.model_intent ?? false, allow_in_plan_mode: overrides.allow_in_plan_mode ?? false, + forward_coder_headers: overrides.forward_coder_headers ?? false, created_at: overrides.created_at ?? now, updated_at: overrides.updated_at ?? now, auth_connected: overrides.auth_connected ?? false, diff --git a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx index a45ea6aa93..5ae2659f90 100644 --- a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx +++ b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx @@ -347,6 +347,7 @@ interface MCPServerFormValues { enabled: boolean; modelIntent: boolean; allowInPlanMode: boolean; + forwardCoderHeaders: boolean; toolAllowList: string; toolDenyList: string; customHeaders: Array<{ key: string; value: string }>; @@ -377,6 +378,7 @@ const buildInitialValues = ( enabled: server?.enabled ?? true, modelIntent: server?.model_intent ?? false, allowInPlanMode: server?.allow_in_plan_mode ?? false, + forwardCoderHeaders: server?.forward_coder_headers ?? false, toolAllowList: joinList(server?.tool_allow_list), toolDenyList: joinList(server?.tool_deny_list), customHeaders: [], @@ -435,6 +437,7 @@ const ServerForm: FC = ({ enabled: values.enabled, model_intent: values.modelIntent, allow_in_plan_mode: values.allowInPlanMode, + forward_coder_headers: values.forwardCoderHeaders, ...(values.authType === "oauth2" && { oauth2_client_id: values.oauth2ClientID.trim(), oauth2_client_secret: effectiveOAuth2Secret, @@ -933,7 +936,8 @@ const ServerForm: FC = ({ Behavior

- Availability, model intent, and tool governance. + Availability, model intent, identity headers, and tool + governance.

{showBehavior ? ( @@ -1017,6 +1021,31 @@ const ServerForm: FC = ({ /> +
+
+ +

+ When enabled, every outgoing MCP request includes the + Coder owner, chat, subchat, and workspace IDs as + X-Coder-* headers. Off by default. Only + enable for first-party or trusted MCP servers. +

+
+ { + form.setFieldValue("forwardCoderHeaders", v); + }} + disabled={isDisabled} + /> +
+