diff --git a/coderd/coderd.go b/coderd/coderd.go index 8921087c73..f2410cfff2 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -93,6 +93,7 @@ import ( "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/coderd/x/chatd" + "github.com/coder/coder/v2/coderd/x/chatd/mcpclient" "github.com/coder/coder/v2/coderd/x/gitsync" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/drpcsdk" @@ -777,6 +778,14 @@ func New(options *Options) *API { maxChatsPerAcquire = math.MinInt32 } + var oidcMCPSrc mcpclient.UserOIDCTokenSource + if options.OIDCConfig != nil { + oidcMCPSrc = newOIDCMCPTokenSource( + options.Database, + options.OIDCConfig, + options.Logger.Named("mcp-user-oidc"), + ) + } api.chatDaemon = chatd.New(chatd.Config{ Logger: options.Logger.Named("chatd"), Database: options.Database, @@ -794,6 +803,7 @@ func New(options *Options) *API { WebpushDispatcher: options.WebPushDispatcher, UsageTracker: options.WorkspaceUsageTracker, PrometheusRegistry: options.PrometheusRegistry, + OIDCTokenSource: oidcMCPSrc, }).Start() gitSyncLogger := options.Logger.Named("gitsync") refresher := gitsync.NewRefresher( diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 01ed07de6b..99109c9486 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1758,7 +1758,7 @@ CREATE TABLE mcp_server_configs ( updated_at timestamp with time zone DEFAULT now() NOT NULL, model_intent boolean DEFAULT false NOT NULL, allow_in_plan_mode boolean DEFAULT false NOT NULL, - CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))), + CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text, 'user_oidc'::text]))), CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))), CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text]))) ); diff --git a/coderd/database/migrations/000484_mcp_user_oidc_auth.down.sql b/coderd/database/migrations/000484_mcp_user_oidc_auth.down.sql new file mode 100644 index 0000000000..245e0060c4 --- /dev/null +++ b/coderd/database/migrations/000484_mcp_user_oidc_auth.down.sql @@ -0,0 +1,10 @@ +-- Rolling this migration back deletes any rows using the user_oidc auth +-- type because they would otherwise violate the restored CHECK constraint. +DELETE FROM mcp_server_configs WHERE auth_type = 'user_oidc'; + +ALTER TABLE mcp_server_configs + DROP CONSTRAINT mcp_server_configs_auth_type_check; + +ALTER TABLE mcp_server_configs + ADD CONSTRAINT mcp_server_configs_auth_type_check + CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers')); diff --git a/coderd/database/migrations/000484_mcp_user_oidc_auth.up.sql b/coderd/database/migrations/000484_mcp_user_oidc_auth.up.sql new file mode 100644 index 0000000000..cb27a30cef --- /dev/null +++ b/coderd/database/migrations/000484_mcp_user_oidc_auth.up.sql @@ -0,0 +1,6 @@ +ALTER TABLE mcp_server_configs + DROP CONSTRAINT mcp_server_configs_auth_type_check; + +ALTER TABLE mcp_server_configs + ADD CONSTRAINT mcp_server_configs_auth_type_check + CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers', 'user_oidc')); diff --git a/coderd/mcp.go b/coderd/mcp.go index 1059530629..b3b7d5619f 100644 --- a/coderd/mcp.go +++ b/coderd/mcp.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -21,14 +22,124 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/x/chatd/mcpclient" "github.com/coder/coder/v2/codersdk" ) +// oidcMCPTokenSource implements mcpclient.UserOIDCTokenSource using +// the same refresh strategy as provisionerdserver.ObtainOIDCAccessToken. +// The logic is duplicated to avoid importing provisionerdserver from +// coderd; keep the two in sync. +type oidcMCPTokenSource struct { + db database.Store + config promoauth.OAuth2Config + logger slog.Logger +} + +// newOIDCMCPTokenSource returns nil when no OIDC provider is +// configured. mcpclient treats a nil source the same as "no token +// available" and omits the Authorization header. +func newOIDCMCPTokenSource(db database.Store, config promoauth.OAuth2Config, logger slog.Logger) mcpclient.UserOIDCTokenSource { + if config == nil { + return nil + } + return &oidcMCPTokenSource{ + db: db, + config: config, + logger: logger, + } +} + +// OIDCAccessToken implements mcpclient.UserOIDCTokenSource. It +// refreshes expired tokens and persists the refreshed token back +// to user_links. The chatd dbauthz subject does not grant +// ResourceSystem.Read or ResourceUser.UpdatePersonal, so DB calls +// elevate to AsSystemRestricted; the per-user authorization is +// already enforced by the API handler that owns ctx. +func (s *oidcMCPTokenSource) OIDCAccessToken(ctx context.Context, userID uuid.UUID) (string, error) { + //nolint:gocritic // user_links read needs system access; the + // caller's user identity is supplied via the userID parameter. + dbCtx := dbauthz.AsSystemRestricted(ctx) + link, err := s.db.GetUserLinkByUserIDLoginType(dbCtx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: userID, + LoginType: database.LoginTypeOIDC, + }) + if errors.Is(err, sql.ErrNoRows) { + return "", nil + } + if err != nil { + return "", xerrors.Errorf("get oidc user link: %w", err) + } + + if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh { + token, err := s.config.TokenSource(ctx, &oauth2.Token{ + AccessToken: link.OAuthAccessToken, + RefreshToken: link.OAuthRefreshToken, + // Use the expiresAt returned by shouldRefreshOIDCToken. + // It will force a refresh with an expired time. + Expiry: expiresAt, + }).Token() + if err != nil { + // Don't fail the request; the upstream MCP server will see no + // Authorization header and can return a 401 if it requires one. + s.logger.Warn(ctx, "failed to refresh OIDC token for MCP request", + slog.F("user_id", userID), + slog.Error(err), + ) + return "", nil + } + link.OAuthAccessToken = token.AccessToken + link.OAuthRefreshToken = token.RefreshToken + link.OAuthExpiry = token.Expiry + + // Persist on a detached context so a canceled chat request + // cannot drop a refresh-token rotation, see PR #24332. + persistCtx, persistCancel := context.WithTimeout( + context.WithoutCancel(dbCtx), 10*time.Second, + ) + link, err = s.db.UpdateUserLink(persistCtx, database.UpdateUserLinkParams{ + UserID: userID, + LoginType: database.LoginTypeOIDC, + OAuthAccessToken: link.OAuthAccessToken, + OAuthAccessTokenKeyID: sql.NullString{}, // set by dbcrypt if required + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required + OAuthExpiry: link.OAuthExpiry, + Claims: link.Claims, + }) + persistCancel() + if err != nil { + return "", xerrors.Errorf("update user link after oidc refresh: %w", err) + } + s.logger.Info(ctx, "refreshed expired OIDC token for MCP request", + slog.F("user_id", userID), + ) + } + + return link.OAuthAccessToken, nil +} + +// shouldRefreshOIDCToken mirrors provisionerdserver.shouldRefreshOIDCToken. +// See that function for the rationale behind the 10-minute pre-expiry +// buffer. +func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) { + if link.OAuthRefreshToken == "" { + return false, link.OAuthExpiry + } + if link.OAuthExpiry.IsZero() { + // A zero expiry means the token never expires. + return false, link.OAuthExpiry + } + expiresAt := link.OAuthExpiry.Add(-time.Minute * 10) + return expiresAt.Before(dbtime.Now()), expiresAt +} + // @Summary List MCP server configs // @x-apidocgen {"skip": true} // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -629,6 +740,21 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) { apiKeyHeader = "" apiKeyValue = "" apiKeyValueKeyID = sql.NullString{} + case "user_oidc": + // user_oidc forwards the calling user's OIDC access token + // from user_links at request time, so no admin-configured + // secrets are stored on the row. + oauth2ClientID = "" + oauth2ClientSecret = "" + oauth2ClientSecretKeyID = sql.NullString{} + oauth2AuthURL = "" + oauth2TokenURL = "" + oauth2Scopes = "" + apiKeyHeader = "" + apiKeyValue = "" + apiKeyValueKeyID = sql.NullString{} + customHeaders = "{}" + customHeadersKeyID = sql.NullString{} } } diff --git a/coderd/mcp_internal_test.go b/coderd/mcp_internal_test.go new file mode 100644 index 0000000000..8c757a638d --- /dev/null +++ b/coderd/mcp_internal_test.go @@ -0,0 +1,216 @@ +package coderd + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/testutil" +) + +// dbauthzTestStore wraps the test database with the same dbauthz layer +// used in production (coderd.go:370). Without it the test would not +// catch RBAC failures from the chatd subject; with it the test fails +// loudly if the elevation in OIDCAccessToken is removed or weakened. +func dbauthzTestStore(t *testing.T, db database.Store) database.Store { + t.Helper() + + authz := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) + acs := &atomic.Pointer[dbauthz.AccessControlStore]{} + var tacs dbauthz.AccessControlStore = fakeAccessControlStore{} + acs.Store(&tacs) + return dbauthz.New(db, authz, testutil.Logger(t), acs) +} + +// fakeAccessControlStore mirrors coderdtest.FakeAccessControlStore but is +// inlined here to avoid an import cycle (coderdtest imports coderd). +type fakeAccessControlStore struct{} + +func (fakeAccessControlStore) GetTemplateAccessControl(t database.Template) dbauthz.TemplateAccessControl { + return dbauthz.TemplateAccessControl{ + RequireActiveVersion: t.RequireActiveVersion, + } +} + +func (fakeAccessControlStore) SetTemplateAccessControl(context.Context, database.Store, uuid.UUID, dbauthz.TemplateAccessControl) error { + panic("not implemented") +} + +func TestShouldRefreshOIDCToken(t *testing.T) { + t.Parallel() + + now := dbtime.Now() + cases := []struct { + name string + link database.UserLink + want bool + }{ + { + name: "NoRefreshToken", + link: database.UserLink{OAuthExpiry: now.Add(-time.Hour)}, + }, + { + name: "ZeroExpiry", + link: database.UserLink{OAuthRefreshToken: "refresh"}, + }, + { + name: "Expired", + link: database.UserLink{ + OAuthRefreshToken: "refresh", + OAuthExpiry: now.Add(-time.Hour), + }, + want: true, + }, + { + name: "Fresh", + link: database.UserLink{ + OAuthRefreshToken: "refresh", + OAuthExpiry: now.Add(time.Hour), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, _ := shouldRefreshOIDCToken(tc.link) + require.Equal(t, tc.want, got) + }) + } +} + +func TestOIDCMCPTokenSource(t *testing.T) { + t.Parallel() + + logger := testutil.Logger(t) + + t.Run("NilConfig", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + require.Nil(t, newOIDCMCPTokenSource(db, nil, logger)) + }) + + t.Run("NoLink", func(t *testing.T) { + // When the user has no OIDC link the source returns ("", nil) + // rather than an error so the caller can fall through to + // "no Authorization header". + t.Parallel() + db, _ := dbtestutil.NewDB(t) + store := dbauthzTestStore(t, db) + user := dbgen.User(t, db, database.User{LoginType: database.LoginTypeOIDC}) + + src := newOIDCMCPTokenSource(store, &testutil.OAuth2Config{}, logger) + ctx := dbauthz.AsChatd(context.Background()) + + tok, err := src.OIDCAccessToken(ctx, user.ID) + require.NoError(t, err) + require.Empty(t, tok) + }) + + t.Run("FreshToken", func(t *testing.T) { + // A non-expired token is returned as-is; no refresh is performed. + t.Parallel() + db, _ := dbtestutil.NewDB(t) + store := dbauthzTestStore(t, db) + user := dbgen.User(t, db, database.User{}) + dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + OAuthAccessToken: "fresh", + OAuthRefreshToken: "refresh", + OAuthExpiry: dbtime.Now().Add(time.Hour), + }) + + src := newOIDCMCPTokenSource(store, &testutil.OAuth2Config{ + Token: &oauth2.Token{AccessToken: "should-not-be-used"}, + }, logger) + ctx := dbauthz.AsChatd(context.Background()) + + tok, err := src.OIDCAccessToken(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "fresh", tok) + }) + + t.Run("RefreshExpired", func(t *testing.T) { + // An expired token triggers a refresh; the new token is + // persisted via UpdateUserLink. This exercises the dbauthz + // elevation: chatd lacks ResourceSystem.Read and + // ResourceUser.UpdatePersonal so a non-elevated context + // would fail both reads and writes. + t.Parallel() + db, _ := dbtestutil.NewDB(t) + store := dbauthzTestStore(t, db) + user := dbgen.User(t, db, database.User{}) + dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + OAuthAccessToken: "stale", + OAuthRefreshToken: "refresh", + OAuthExpiry: dbtime.Now().Add(-time.Hour), + }) + + src := newOIDCMCPTokenSource(store, &testutil.OAuth2Config{ + Token: &oauth2.Token{ + AccessToken: "fresh", + RefreshToken: "new-refresh", + Expiry: dbtime.Now().Add(time.Hour), + }, + }, logger) + ctx := dbauthz.AsChatd(context.Background()) + + tok, err := src.OIDCAccessToken(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "fresh", tok) + + // Verify the refresh was persisted via UpdateUserLink. + got, err := db.GetUserLinkByUserIDLoginType( + dbauthz.AsSystemRestricted(context.Background()), + database.GetUserLinkByUserIDLoginTypeParams{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + }, + ) + require.NoError(t, err) + require.Equal(t, "fresh", got.OAuthAccessToken) + require.Equal(t, "new-refresh", got.OAuthRefreshToken) + }) + + t.Run("RefreshFailureReturnsEmpty", func(t *testing.T) { + // A refresh attempt that fails (e.g. invalid client config) + // must not surface an error to the caller; per the + // UserOIDCTokenSource contract this is treated as "no + // Authorization header". + t.Parallel() + db, _ := dbtestutil.NewDB(t) + store := dbauthzTestStore(t, db) + user := dbgen.User(t, db, database.User{}) + dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + OAuthAccessToken: "stale", + OAuthRefreshToken: "refresh", + OAuthExpiry: dbtime.Now().Add(-time.Hour), + }) + + // An empty oauth2.Config triggers a refresh failure + // because it has no token endpoint to call. + src := newOIDCMCPTokenSource(store, &oauth2.Config{}, logger) + ctx := dbauthz.AsChatd(context.Background()) + + tok, err := src.OIDCAccessToken(ctx, user.ID) + require.NoError(t, err) + require.Empty(t, tok) + }) +} diff --git a/coderd/mcp_test.go b/coderd/mcp_test.go index 1ac13b42d4..60ebf7c551 100644 --- a/coderd/mcp_test.go +++ b/coderd/mcp_test.go @@ -316,18 +316,109 @@ func TestMCPServerConfigsAuthConnected(t *testing.T) { // Also create a non-oauth server. It should report // auth_connected=true because no auth is needed. _ = createMCPServerConfig(t, adminClient, "no-auth-server", true) + + // And a user_oidc server. user_oidc never requires a per-user + // connect step, so auth_connected is always true regardless of + // whether the calling user has an OIDC link. + _, err = adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "User OIDC Server", + Slug: "user-oidc-server", + Transport: "streamable_http", + URL: "https://mcp.example.com/oidc", + AuthType: "user_oidc", + Availability: "default_on", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + memberConfigs, err = memberClient.MCPServerConfigs(ctx) require.NoError(t, err) - require.Len(t, memberConfigs, 2) + require.Len(t, memberConfigs, 3) for _, cfg := range memberConfigs { - if cfg.AuthType == "none" { - require.True(t, cfg.AuthConnected) - } else { - require.False(t, cfg.AuthConnected) + switch cfg.AuthType { + case "none", "user_oidc": + require.True(t, cfg.AuthConnected, "%s should report auth_connected", cfg.AuthType) + default: + require.False(t, cfg.AuthConnected, "%s should not report auth_connected", cfg.AuthType) } } } +func TestMCPServerConfigsUserOIDCClearsFields(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newMCPClient(t) + _ = coderdtest.CreateFirstUser(t, client) + + // Start with an oauth2 config that has a client secret, then + // switch the auth_type to user_oidc and verify all auth-specific + // fields are cleared. + created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "Switch Server", + Slug: "switch-server", + Transport: "streamable_http", + URL: "https://mcp.example.com/v1", + AuthType: "oauth2", + OAuth2ClientID: "cid", + OAuth2ClientSecret: "secret-value", + OAuth2AuthURL: "https://auth.example.com/authorize", + OAuth2TokenURL: "https://auth.example.com/token", + OAuth2Scopes: "read write", + Availability: "default_off", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + require.True(t, created.HasOAuth2Secret) + require.Equal(t, "cid", created.OAuth2ClientID) + + newAuth := "user_oidc" + updated, err := client.UpdateMCPServerConfig(ctx, created.ID, codersdk.UpdateMCPServerConfigRequest{ + AuthType: &newAuth, + }) + require.NoError(t, err) + require.Equal(t, "user_oidc", updated.AuthType) + require.False(t, updated.HasOAuth2Secret, "oauth2 secret should be cleared") + require.False(t, updated.HasAPIKey, "api key should remain unset") + require.False(t, updated.HasCustomHeaders, "custom headers should remain unset") + require.Empty(t, updated.OAuth2ClientID) + require.Empty(t, updated.OAuth2AuthURL) + require.Empty(t, updated.OAuth2TokenURL) + require.Empty(t, updated.OAuth2Scopes) + require.Empty(t, updated.APIKeyHeader) +} + +func TestMCPServerConfigsUserOIDCDirect(t *testing.T) { + t.Parallel() + + // Create with user_oidc and confirm validation accepts the value + // while no auth-specific fields are persisted on the row. + ctx := testutil.Context(t, testutil.WaitLong) + client := newMCPClient(t) + _ = coderdtest.CreateFirstUser(t, client) + + created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "User OIDC Direct", + Slug: "user-oidc-direct", + Transport: "streamable_http", + URL: "https://mcp.example.com/oidc-direct", + AuthType: "user_oidc", + Availability: "default_off", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + require.Equal(t, "user_oidc", created.AuthType) + require.False(t, created.HasOAuth2Secret) + require.False(t, created.HasAPIKey) + require.False(t, created.HasCustomHeaders) +} + func TestMCPServerConfigsAvailability(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 9648ebd4dd..7aa4426417 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -162,6 +162,7 @@ type Server struct { pubsub pubsub.Pubsub webpushDispatcher webpush.Dispatcher providerAPIKeys chatprovider.ProviderAPIKeys + oidcTokenSource mcpclient.UserOIDCTokenSource debugSvc *chatdebug.Service debugSvcFactory func() *chatdebug.Service debugSvcReady atomic.Bool @@ -3840,6 +3841,12 @@ type Config struct { UsageTracker *workspacestats.UsageTracker Clock quartz.Clock PrometheusRegistry prometheus.Registerer + + // OIDCTokenSource resolves the calling user's OIDC access + // token for MCP servers configured with auth_type=user_oidc. + // May be nil if the deployment has no OIDC provider; servers + // using user_oidc will then send no Authorization header. + OIDCTokenSource mcpclient.UserOIDCTokenSource } // New creates a new chat processor. The processor polls for pending @@ -3898,6 +3905,7 @@ func New(cfg Config) *Server { pubsub: cfg.Pubsub, webpushDispatcher: cfg.WebpushDispatcher, providerAPIKeys: cfg.ProviderAPIKeys, + oidcTokenSource: cfg.OIDCTokenSource, debugSvcFactory: func() *chatdebug.Service { debugSvc := chatdebug.NewService( cfg.Database, @@ -6472,7 +6480,7 @@ func (p *Server) runChat( // Refresh expired OAuth2 tokens before connecting. mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens) mcpTools, mcpCleanup = mcpclient.ConnectAll( - ctx, logger, mcpConnectConfigs, mcpTokens, + ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource, ) return nil }) diff --git a/coderd/x/chatd/mcpclient/mcpclient.go b/coderd/x/chatd/mcpclient/mcpclient.go index ce6755cfa6..8b57e9b3a0 100644 --- a/coderd/x/chatd/mcpclient/mcpclient.go +++ b/coderd/x/chatd/mcpclient/mcpclient.go @@ -49,6 +49,18 @@ const connectTimeout = 10 * time.Second // take before being canceled. const toolCallTimeout = 60 * time.Second +// UserOIDCTokenSource resolves the OIDC access token for the calling +// user. Implementations attempt to refresh tokens that are expired +// or close to expiring and MUST return ("", nil) when the user has +// no OIDC link or a refresh attempt failed for any reason. A +// non-nil error is reserved for unexpected infrastructure failures +// (e.g. database errors) and skips header construction entirely. +// The empty-token-on-refresh-failure behavior matches +// provisionerdserver.ObtainOIDCAccessToken. +type UserOIDCTokenSource interface { + OIDCAccessToken(ctx context.Context, userID uuid.UUID) (string, error) +} + // ConnectAll connects to all configured MCP servers, discovers // their tools, and returns them as fantasy.AgentTool values. // Tools are sorted by their prefixed name so callers @@ -60,6 +72,8 @@ func ConnectAll( logger slog.Logger, configs []database.MCPServerConfig, tokens []database.MCPServerUserToken, + userID uuid.UUID, + oidcSrc UserOIDCTokenSource, ) ([]fantasy.AgentTool, func()) { // Index tokens by server config ID so auth header // construction is O(1) per server. @@ -95,7 +109,7 @@ func ConnectAll( eg.Go(func() error { serverTools, mcpClient, connectErr := connectOne( - ctx, logger, cfg, tokensByConfigID, + ctx, logger, cfg, tokensByConfigID, userID, oidcSrc, ) if connectErr != nil { logger.Warn(ctx, @@ -159,8 +173,10 @@ func connectOne( logger slog.Logger, cfg database.MCPServerConfig, tokensByConfigID map[uuid.UUID]database.MCPServerUserToken, + userID uuid.UUID, + oidcSrc UserOIDCTokenSource, ) ([]fantasy.AgentTool, *client.Client, error) { - headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID) + headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userID, oidcSrc) tr, err := createTransport(cfg, headers) if err != nil { @@ -285,6 +301,8 @@ func buildAuthHeaders( logger slog.Logger, cfg database.MCPServerConfig, tokensByConfigID map[uuid.UUID]database.MCPServerUserToken, + userID uuid.UUID, + oidcSrc UserOIDCTokenSource, ) map[string]string { // Using map[string]string rather than http.Header because // the mcp-go transport options accept map[string]string. @@ -347,6 +365,40 @@ func buildAuthHeaders( } } } + case "user_oidc": + // Forward the calling user's OIDC access token from + // user_links as Authorization: Bearer . The token + // source is responsible for refreshing tokens that are + // expired or close to expiring before returning them. + if oidcSrc == nil || userID == uuid.Nil { + logger.Warn(ctx, + "user_oidc auth requested but no token source available", + slog.F("server_slug", cfg.Slug), + ) + break + } + token, err := oidcSrc.OIDCAccessToken(ctx, userID) + if err != nil { + logger.Warn(ctx, + "failed to obtain user OIDC token for MCP server", + slog.F("server_slug", cfg.Slug), + slog.Error(err), + ) + break + } + if token == "" { + // The user has no OIDC link, or a non-fatal refresh + // failure occurred. Fall through with no header and let + // the upstream MCP server decide how to respond + // (typically 401). Logged at debug so password and + // GitHub users don't generate noise for every chat turn. + logger.Debug(ctx, + "no user OIDC token available for MCP server", + slog.F("server_slug", cfg.Slug), + ) + break + } + headers["Authorization"] = "Bearer " + token case "none", "": // No auth headers needed. } diff --git a/coderd/x/chatd/mcpclient/mcpclient_test.go b/coderd/x/chatd/mcpclient/mcpclient_test.go index 7676296360..dca1c5a1b8 100644 --- a/coderd/x/chatd/mcpclient/mcpclient_test.go +++ b/coderd/x/chatd/mcpclient/mcpclient_test.go @@ -96,7 +96,7 @@ func TestConnectAll_DiscoverTools(t *testing.T) { ts := newTestMCPServer(t, echoTool(), greetTool()) cfg := makeConfig("myserver", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) // Two tools should be discovered, namespaced with the server slug. @@ -121,7 +121,7 @@ func TestConnectAll_CallTool(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -147,7 +147,7 @@ func TestConnectAll_ToolAllowList(t *testing.T) { // Only allow the "echo" tool. cfg.ToolAllowList = []string{"echo"} - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -165,7 +165,7 @@ func TestConnectAll_ToolDenyList(t *testing.T) { // Deny the "greet" tool, so only "echo" remains. cfg.ToolDenyList = []string{"greet"} - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -179,7 +179,7 @@ func TestConnectAll_ConnectionFailure(t *testing.T) { cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist") - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) assert.Empty(t, tools, "no tools should be returned for an unreachable server") @@ -200,6 +200,7 @@ func TestConnectAll_MultipleServers(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg1, cfg2}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -225,6 +226,7 @@ func TestConnectAll_NoToolsAfterFiltering(t *testing.T) { logger, []database.MCPServerConfig{cfg}, nil, + uuid.Nil, nil, ) require.Empty(t, tools) @@ -252,6 +254,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { makeConfig("srv2", ts2.URL), }, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -280,6 +283,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { makeConfig("aaa", other.URL), }, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -312,6 +316,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) { logger, []database.MCPServerConfig{cfg1, cfg2}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -375,6 +380,7 @@ func TestConnectAll_AuthHeaders(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -429,7 +435,7 @@ func TestConnectAll_DisabledServer(t *testing.T) { cfg := makeConfig("disabled", ts.URL) cfg.Enabled = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) assert.Empty(t, tools) } @@ -444,7 +450,7 @@ func TestConnectAll_CallToolInvalidInput(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -469,7 +475,7 @@ func TestConnectAll_ToolInfoParameters(t *testing.T) { ts := newTestMCPServer(t, echoTool()) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -511,7 +517,7 @@ func TestConnectAll_NilRequiredBecomesEmptySlice(t *testing.T) { ts := newTestMCPServer(t, noRequiredTool) cfg := makeConfig("srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -563,6 +569,7 @@ func TestConnectAll_APIKeyAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -619,6 +626,7 @@ func TestConnectAll_CustomHeadersAuth(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -655,6 +663,7 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) { tools, cleanup := mcpclient.ConnectAll( ctx, logger, []database.MCPServerConfig{cfg}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -664,6 +673,158 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) { assert.Equal(t, "badjson__echo", tools[0].Info().Name) } +// staticOIDCSource implements mcpclient.UserOIDCTokenSource for tests +// without requiring a real OIDC provider or database round-trip. +type staticOIDCSource struct { + token string + err error +} + +func (s staticOIDCSource) OIDCAccessToken(_ context.Context, _ uuid.UUID) (string, error) { + return s.token, s.err +} + +// TestConnectAll_UserOIDCAuth verifies that the user_oidc auth type +// forwards the calling user's OIDC access token from the +// UserOIDCTokenSource as Authorization: Bearer . +func TestConnectAll_UserOIDCAuth(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + var ( + mu sync.Mutex + seenHeaders []string + ) + + srv := mcpserver.NewMCPServer("oidc-server", "1.0.0") + srv.AddTools(mcpserver.ServerTool{ + Tool: mcp.NewTool("whoami", + mcp.WithDescription("Returns the auth header"), + ), + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + auth := req.Header.Get("Authorization") + mu.Lock() + seenHeaders = append(seenHeaders, auth) + mu.Unlock() + return mcp.NewToolResultText("auth:" + auth), nil + }, + }) + + httpSrv := mcpserver.NewStreamableHTTPServer(srv) + ts := httptest.NewServer(httpSrv) + t.Cleanup(ts.Close) + + cfg := makeConfig("oidc-srv", ts.URL) + cfg.AuthType = "user_oidc" + userID := uuid.New() + src := staticOIDCSource{token: "fake-oidc-token"} + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, + userID, src, + ) + t.Cleanup(cleanup) + + require.Len(t, tools, 1) + + resp, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-oidc", + Name: "oidc-srv__whoami", + Input: "{}", + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + assert.Equal(t, "auth:Bearer fake-oidc-token", resp.Content) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, seenHeaders) + assert.Equal(t, "Bearer fake-oidc-token", seenHeaders[len(seenHeaders)-1]) +} + +// TestConnectAll_UserOIDCAuth_NoLink verifies that when the token +// source returns ("", nil) (the user has no OIDC link), the request +// is still made but with no Authorization header. The MCP server is +// then free to respond with 401 or proceed unauthenticated. +func TestConnectAll_UserOIDCAuth_NoLink(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + var ( + mu sync.Mutex + seenHeaders []string + ) + + srv := mcpserver.NewMCPServer("oidc-server-nolink", "1.0.0") + srv.AddTools(mcpserver.ServerTool{ + Tool: mcp.NewTool("whoami", + mcp.WithDescription("Returns the auth header"), + ), + Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + auth := req.Header.Get("Authorization") + mu.Lock() + seenHeaders = append(seenHeaders, auth) + mu.Unlock() + return mcp.NewToolResultText("auth:" + auth), nil + }, + }) + + httpSrv := mcpserver.NewStreamableHTTPServer(srv) + ts := httptest.NewServer(httpSrv) + t.Cleanup(ts.Close) + + cfg := makeConfig("oidc-nolink", ts.URL) + cfg.AuthType = "user_oidc" + src := staticOIDCSource{token: "", err: nil} + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, + uuid.New(), src, + ) + t.Cleanup(cleanup) + + require.Len(t, tools, 1) + + resp, err := tools[0].Run(ctx, fantasy.ToolCall{ + ID: "call-oidc-nolink", + Name: "oidc-nolink__whoami", + Input: "{}", + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + assert.Equal(t, "auth:", resp.Content) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, seenHeaders) + assert.Empty(t, seenHeaders[len(seenHeaders)-1]) +} + +// TestConnectAll_UserOIDCAuth_NilSource verifies that a nil token +// source (e.g. deployment with no OIDC provider) yields no +// Authorization header rather than panicking. +func TestConnectAll_UserOIDCAuth_NilSource(t *testing.T) { + t.Parallel() + ctx := context.Background() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + ts := newTestMCPServer(t, echoTool()) + + cfg := makeConfig("oidc-nilsrc", ts.URL) + cfg.AuthType = "user_oidc" + + tools, cleanup := mcpclient.ConnectAll( + ctx, logger, []database.MCPServerConfig{cfg}, nil, + uuid.New(), nil, + ) + t.Cleanup(cleanup) + + require.Len(t, tools, 1) + assert.Equal(t, "oidc-nilsrc__echo", tools[0].Info().Name) +} + // TestConnectAll_ParallelConnections verifies that connecting to // multiple MCP servers simultaneously returns all discovered // tools with the correct server slug prefixes. @@ -684,6 +845,7 @@ func TestConnectAll_ParallelConnections(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg1, cfg2, cfg3}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -744,7 +906,7 @@ func TestConnectAll_ExpiredToken(t *testing.T) { Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil) t.Cleanup(cleanup) // The server accepts any auth, so the tool is still discovered @@ -777,7 +939,7 @@ func TestConnectAll_EmptyAccessToken(t *testing.T) { TokenType: "Bearer", } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil) t.Cleanup(cleanup) // Tool is still discovered (server doesn't require auth), but @@ -807,7 +969,7 @@ func TestConnectAll_MCPToolIdentifier(t *testing.T) { Enabled: true, } - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -853,6 +1015,7 @@ func TestConnectAll_MCPToolIdentifier_MultipleServers(t *testing.T) { ctx, logger, []database.MCPServerConfig{cfg1, cfg2}, nil, + uuid.Nil, nil, ) t.Cleanup(cleanup) @@ -909,7 +1072,7 @@ func TestConnectAll_EmbeddedResourceText(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("embed-txt", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -976,7 +1139,7 @@ func TestConnectAll_EmbeddedResourceBlob(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("embed-blob", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1056,7 +1219,7 @@ func TestConnectAll_ResourceLink(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("res-link", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1100,7 +1263,7 @@ func TestConnectAll_CallToolError(t *testing.T) { t.Cleanup(ts.Close) cfg := makeConfig("err-srv", ts.URL) - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1124,7 +1287,7 @@ func TestModelIntent_Info_WrapsSchema(t *testing.T) { cfg := makeConfig("intent-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1160,7 +1323,7 @@ func TestModelIntent_Info_NoWrapWhenDisabled(t *testing.T) { cfg := makeConfig("no-intent", ts.URL) cfg.ModelIntent = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1183,7 +1346,7 @@ func TestModelIntent_Run_UnwrapsProperties(t *testing.T) { cfg := makeConfig("unwrap-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1208,7 +1371,7 @@ func TestModelIntent_Run_UnwrapsFlat(t *testing.T) { cfg := makeConfig("flat-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1233,7 +1396,7 @@ func TestModelIntent_Run_PassthroughWhenDisabled(t *testing.T) { cfg := makeConfig("pass-srv", ts.URL) cfg.ModelIntent = false - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) @@ -1258,7 +1421,7 @@ func TestModelIntent_Run_FallbackOnBadJSON(t *testing.T) { cfg := makeConfig("bad-srv", ts.URL) cfg.ModelIntent = true - tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil) + tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil) t.Cleanup(cleanup) require.Len(t, tools, 1) diff --git a/codersdk/mcp.go b/codersdk/mcp.go index 41b3adb220..132c804479 100644 --- a/codersdk/mcp.go +++ b/codersdk/mcp.go @@ -42,7 +42,7 @@ type MCPServerConfig struct { Transport string `json:"transport"` // "streamable_http" or "sse" URL string `json:"url"` - AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers" + AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers", "user_oidc" // OAuth2 fields (only populated for admins). OAuth2ClientID string `json:"oauth2_client_id,omitempty"` @@ -84,7 +84,7 @@ type CreateMCPServerConfigRequest struct { Transport string `json:"transport" validate:"required,oneof=streamable_http sse"` URL string `json:"url" validate:"required,url"` - AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers"` + AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers user_oidc"` OAuth2ClientID string `json:"oauth2_client_id,omitempty"` OAuth2ClientSecret string `json:"oauth2_client_secret,omitempty"` OAuth2AuthURL string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"` @@ -113,7 +113,7 @@ type UpdateMCPServerConfigRequest struct { Transport *string `json:"transport,omitempty" validate:"omitempty,oneof=streamable_http sse"` URL *string `json:"url,omitempty" validate:"omitempty,url"` - AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers"` + AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers user_oidc"` OAuth2ClientID *string `json:"oauth2_client_id,omitempty"` OAuth2ClientSecret *string `json:"oauth2_client_secret,omitempty"` OAuth2AuthURL *string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"` diff --git a/docs/ai-coder/agents/platform-controls/mcp-servers.md b/docs/ai-coder/agents/platform-controls/mcp-servers.md index 0a00e38f85..7deefcb6a9 100644 --- a/docs/ai-coder/agents/platform-controls/mcp-servers.md +++ b/docs/ai-coder/agents/platform-controls/mcp-servers.md @@ -48,7 +48,7 @@ This is an admin-only feature accessible at **Agents** > **Settings** > ## Authentication -Each MCP server uses one of four authentication modes. When you change the +Each MCP server uses one of five authentication modes. When you change the auth type, fields from the previous type are automatically cleared. Secrets are never returned in API responses — boolean flags indicate whether @@ -104,6 +104,21 @@ A static key sent as a header on every request. Arbitrary key-value header pairs sent on every request. At least one header is required when this mode is selected. +### User OIDC Identity + +Forwards the calling user's OIDC access token (stored in +`user_links.oauth_access_token`) to the MCP server as an +`Authorization: Bearer ` header. The token is refreshed +transparently before each request if it has expired or is close to +expiring. + +No admin-configurable fields. No per-user connect step. + +**Limitation**: this auth mode only works for users who authenticated to +Coder via OIDC. Users who logged in with password or GitHub will see +requests sent without an authorization header, and the upstream MCP +server is expected to respond with 401. + ## Tool governance Control which tools from a server are available in chat: diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 58de772320..7c276f78f9 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -4579,7 +4579,7 @@ export interface MCPServerConfig { readonly icon_url: string; readonly transport: string; // "streamable_http" or "sse" readonly url: string; - readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers" + readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers", "user_oidc" /** * OAuth2 fields (only populated for admins). */ diff --git a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx index 7611cb4221..58a605bb2b 100644 --- a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.stories.tsx @@ -698,3 +698,55 @@ export const CustomHeadersAuthType: Story = { ); }, }; + +export const CreateServerUserOIDC: Story = { + args: { + serversData: [], + }, + play: async ({ canvasElement, args }) => { + const body = within(canvasElement.ownerDocument.body); + + await userEvent.click( + await body.findByRole("button", { name: /Add your first server/i }), + ); + + await userEvent.type( + await body.findByLabelText(/Display Name/i), + "Internal API", + ); + await userEvent.type( + body.getByLabelText(/Server URL/i), + "https://mcp.internal.example.com/v1", + ); + + await userEvent.click( + await body.findByRole("button", { name: /Authentication/i }), + ); + await userEvent.click(body.getByLabelText(/Authentication/i)); + await userEvent.click( + await body.findByRole("option", { name: /User OIDC Identity/i }), + ); + + // No additional auth fields for user_oidc; the helper text is shown. + expect( + body.getByText(/forwarded to this MCP server in the/i), + ).toBeInTheDocument(); + + await userEvent.click(body.getByRole("button", { name: /Create server/i })); + + await waitFor(() => { + expect(args.onCreateServer).toHaveBeenCalledTimes(1); + }); + expect(args.onCreateServer).toHaveBeenCalledWith( + expect.objectContaining({ + auth_type: "user_oidc", + }), + ); + // Should not include any oauth2/api_key/custom_headers fields. + const call = (args.onCreateServer as ReturnType).mock + .calls[0][0]; + expect(call).not.toHaveProperty("oauth2_client_id"); + expect(call).not.toHaveProperty("api_key_value"); + expect(call).not.toHaveProperty("custom_headers"); + }, +}; diff --git a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx index 2fb7e64f9f..a45ea6aa93 100644 --- a/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx +++ b/site/src/pages/AgentsPage/components/MCPServerAdminPanel.tsx @@ -69,6 +69,7 @@ const AUTH_TYPE_OPTIONS = [ { value: "oauth2", label: "OAuth2" }, { value: "api_key", label: "API Key" }, { value: "custom_headers", label: "Custom Headers" }, + { value: "user_oidc", label: "User OIDC Identity" }, ] as const; const AVAILABILITY_OPTIONS = [ @@ -900,6 +901,22 @@ const ServerForm: FC = ({ )} + + {form.values.authType === "user_oidc" && ( +
+

+ The calling user's OIDC access token is forwarded to this + MCP server in the Authorization header. + Tokens are refreshed transparently before each request. +

+

+ Users who did not log in via OIDC (for example, password + or GitHub login) will see requests sent without an + authorization header. Configure no other fields for this + auth type. +

+
+ )}