mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add user_oidc auth type for MCP servers (#24793)
Adds a 5th MCP server authentication mode, `user_oidc` ("User OIDC
Identity"), that forwards the calling user's OIDC access token from
`user_links.oauth_access_token` to the upstream MCP server as
`Authorization: Bearer <token>`.
The token is read from `user_links` and refreshed transparently via
`oauth2.TokenSource` before each MCP request. No new per-MCP-server
secret storage and no per-user connect/disconnect step.
**Limitation**: only users who logged in via OIDC have a forwardable
token. Users authenticated via password or GitHub will see requests sent
without an `Authorization` header, and the upstream MCP server is
expected to respond with 401. A pluggable token source (e.g. CLI-minted
E2E tokens) is left as future work.
<details>
<summary>Implementation notes</summary>
- Schema: new
`coderd/database/migrations/000481_mcp_user_oidc_auth.{up,down}.sql`
relaxes the `mcp_server_configs.auth_type` CHECK constraint to include
`user_oidc`. Down migration deletes affected rows before restoring the
old constraint.
- SDK validation: `codersdk/mcp.go` extends `oneof` for
`CreateMCPServerConfigRequest` and `UpdateMCPServerConfigRequest`.
- Handler: `coderd/mcp.go` adds `case "user_oidc":` to the
field-clearing switch on update. The existing list and detail handlers
already report `auth_connected = true` for any non-`oauth2` auth type.
- Header construction: `coderd/x/chatd/mcpclient/mcpclient.go`
introduces a `UserOIDCTokenSource` interface and adds the `user_oidc`
case to `buildAuthHeaders`. `ConnectAll` / `connectOne` /
`buildAuthHeaders` gain `userID uuid.UUID, oidcSrc UserOIDCTokenSource`
parameters.
- Wiring: `coderd/x/chatd/chatd.go` adds `OIDCTokenSource` to `Config` /
`Server` and passes `chat.OwnerID` plus the source through `ConnectAll`.
`coderd/coderd.go` constructs the source next to the `chatd.New` call
when `options.OIDCConfig` is non-nil.
- Token source: `oidcMCPTokenSource` lives in `coderd/mcp.go`. It reads
the user's OIDC link, refreshes via `oauth2.TokenSource`, and writes the
refreshed token back to `user_links`. Logic is duplicated from
`provisionerdserver.ObtainOIDCAccessToken` to avoid an MCP ->
provisionerdserver dependency. The two copies must be kept in sync; a
comment on `oidcMCPTokenSource` records this.
- Frontend: `MCPServerAdminPanel.tsx` adds the new dropdown option, an
explanatory helper block (no admin-configurable fields), and a Storybook
story (`CreateServerUserOIDC`).
- Tests:
- `mcpclient_test.go`: `TestConnectAll_UserOIDCAuth`,
`TestConnectAll_UserOIDCAuth_NoLink`,
`TestConnectAll_UserOIDCAuth_NilSource`. All existing tests updated for
the new signature.
- `mcp_test.go`: extends `TestMCPServerConfigsAuthConnected` to assert
`auth_connected=true` for `user_oidc`; adds
`TestMCPServerConfigsUserOIDCClearsFields` and
`TestMCPServerConfigsUserOIDCDirect`.
- Docs: `docs/ai-coder/agents/platform-controls/mcp-servers.md`
describes the new mode and its OIDC-only limitation.
</details>
This PR was created by Coder Agents.
---------
Co-authored-by: Coder Agents <agents@coder.com>
This commit is contained in:
@@ -93,6 +93,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/wsbuilder"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
@@ -777,6 +778,14 @@ func New(options *Options) *API {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
var oidcMCPSrc mcpclient.UserOIDCTokenSource
|
||||
if options.OIDCConfig != nil {
|
||||
oidcMCPSrc = newOIDCMCPTokenSource(
|
||||
options.Database,
|
||||
options.OIDCConfig,
|
||||
options.Logger.Named("mcp-user-oidc"),
|
||||
)
|
||||
}
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
@@ -794,6 +803,7 @@ func New(options *Options) *API {
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
PrometheusRegistry: options.PrometheusRegistry,
|
||||
OIDCTokenSource: oidcMCPSrc,
|
||||
}).Start()
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
|
||||
Generated
+1
-1
@@ -1758,7 +1758,7 @@ CREATE TABLE mcp_server_configs (
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
model_intent boolean DEFAULT false NOT NULL,
|
||||
allow_in_plan_mode boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))),
|
||||
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text, 'user_oidc'::text]))),
|
||||
CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))),
|
||||
CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text])))
|
||||
);
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
-- Rolling this migration back deletes any rows using the user_oidc auth
|
||||
-- type because they would otherwise violate the restored CHECK constraint.
|
||||
DELETE FROM mcp_server_configs WHERE auth_type = 'user_oidc';
|
||||
|
||||
ALTER TABLE mcp_server_configs
|
||||
DROP CONSTRAINT mcp_server_configs_auth_type_check;
|
||||
|
||||
ALTER TABLE mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_auth_type_check
|
||||
CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers'));
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE mcp_server_configs
|
||||
DROP CONSTRAINT mcp_server_configs_auth_type_check;
|
||||
|
||||
ALTER TABLE mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_auth_type_check
|
||||
CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers', 'user_oidc'));
|
||||
+126
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -21,14 +22,124 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// oidcMCPTokenSource implements mcpclient.UserOIDCTokenSource using
|
||||
// the same refresh strategy as provisionerdserver.ObtainOIDCAccessToken.
|
||||
// The logic is duplicated to avoid importing provisionerdserver from
|
||||
// coderd; keep the two in sync.
|
||||
type oidcMCPTokenSource struct {
|
||||
db database.Store
|
||||
config promoauth.OAuth2Config
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
// newOIDCMCPTokenSource returns nil when no OIDC provider is
|
||||
// configured. mcpclient treats a nil source the same as "no token
|
||||
// available" and omits the Authorization header.
|
||||
func newOIDCMCPTokenSource(db database.Store, config promoauth.OAuth2Config, logger slog.Logger) mcpclient.UserOIDCTokenSource {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
return &oidcMCPTokenSource{
|
||||
db: db,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// OIDCAccessToken implements mcpclient.UserOIDCTokenSource. It
|
||||
// refreshes expired tokens and persists the refreshed token back
|
||||
// to user_links. The chatd dbauthz subject does not grant
|
||||
// ResourceSystem.Read or ResourceUser.UpdatePersonal, so DB calls
|
||||
// elevate to AsSystemRestricted; the per-user authorization is
|
||||
// already enforced by the API handler that owns ctx.
|
||||
func (s *oidcMCPTokenSource) OIDCAccessToken(ctx context.Context, userID uuid.UUID) (string, error) {
|
||||
//nolint:gocritic // user_links read needs system access; the
|
||||
// caller's user identity is supplied via the userID parameter.
|
||||
dbCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
link, err := s.db.GetUserLinkByUserIDLoginType(dbCtx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: userID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("get oidc user link: %w", err)
|
||||
}
|
||||
|
||||
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
|
||||
token, err := s.config.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
// Use the expiresAt returned by shouldRefreshOIDCToken.
|
||||
// It will force a refresh with an expired time.
|
||||
Expiry: expiresAt,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
// Don't fail the request; the upstream MCP server will see no
|
||||
// Authorization header and can return a 401 if it requires one.
|
||||
s.logger.Warn(ctx, "failed to refresh OIDC token for MCP request",
|
||||
slog.F("user_id", userID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return "", nil
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
|
||||
// Persist on a detached context so a canceled chat request
|
||||
// cannot drop a refresh-token rotation, see PR #24332.
|
||||
persistCtx, persistCancel := context.WithTimeout(
|
||||
context.WithoutCancel(dbCtx), 10*time.Second,
|
||||
)
|
||||
link, err = s.db.UpdateUserLink(persistCtx, database.UpdateUserLinkParams{
|
||||
UserID: userID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // set by dbcrypt if required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
Claims: link.Claims,
|
||||
})
|
||||
persistCancel()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("update user link after oidc refresh: %w", err)
|
||||
}
|
||||
s.logger.Info(ctx, "refreshed expired OIDC token for MCP request",
|
||||
slog.F("user_id", userID),
|
||||
)
|
||||
}
|
||||
|
||||
return link.OAuthAccessToken, nil
|
||||
}
|
||||
|
||||
// shouldRefreshOIDCToken mirrors provisionerdserver.shouldRefreshOIDCToken.
|
||||
// See that function for the rationale behind the 10-minute pre-expiry
|
||||
// buffer.
|
||||
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return false, link.OAuthExpiry
|
||||
}
|
||||
if link.OAuthExpiry.IsZero() {
|
||||
// A zero expiry means the token never expires.
|
||||
return false, link.OAuthExpiry
|
||||
}
|
||||
expiresAt := link.OAuthExpiry.Add(-time.Minute * 10)
|
||||
return expiresAt.Before(dbtime.Now()), expiresAt
|
||||
}
|
||||
|
||||
// @Summary List MCP server configs
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -629,6 +740,21 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKeyHeader = ""
|
||||
apiKeyValue = ""
|
||||
apiKeyValueKeyID = sql.NullString{}
|
||||
case "user_oidc":
|
||||
// user_oidc forwards the calling user's OIDC access token
|
||||
// from user_links at request time, so no admin-configured
|
||||
// secrets are stored on the row.
|
||||
oauth2ClientID = ""
|
||||
oauth2ClientSecret = ""
|
||||
oauth2ClientSecretKeyID = sql.NullString{}
|
||||
oauth2AuthURL = ""
|
||||
oauth2TokenURL = ""
|
||||
oauth2Scopes = ""
|
||||
apiKeyHeader = ""
|
||||
apiKeyValue = ""
|
||||
apiKeyValueKeyID = sql.NullString{}
|
||||
customHeaders = "{}"
|
||||
customHeadersKeyID = sql.NullString{}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// dbauthzTestStore wraps the test database with the same dbauthz layer
|
||||
// used in production (coderd.go:370). Without it the test would not
|
||||
// catch RBAC failures from the chatd subject; with it the test fails
|
||||
// loudly if the elevation in OIDCAccessToken is removed or weakened.
|
||||
func dbauthzTestStore(t *testing.T, db database.Store) database.Store {
|
||||
t.Helper()
|
||||
|
||||
authz := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
acs := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var tacs dbauthz.AccessControlStore = fakeAccessControlStore{}
|
||||
acs.Store(&tacs)
|
||||
return dbauthz.New(db, authz, testutil.Logger(t), acs)
|
||||
}
|
||||
|
||||
// fakeAccessControlStore mirrors coderdtest.FakeAccessControlStore but is
|
||||
// inlined here to avoid an import cycle (coderdtest imports coderd).
|
||||
type fakeAccessControlStore struct{}
|
||||
|
||||
func (fakeAccessControlStore) GetTemplateAccessControl(t database.Template) dbauthz.TemplateAccessControl {
|
||||
return dbauthz.TemplateAccessControl{
|
||||
RequireActiveVersion: t.RequireActiveVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func (fakeAccessControlStore) SetTemplateAccessControl(context.Context, database.Store, uuid.UUID, dbauthz.TemplateAccessControl) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func TestShouldRefreshOIDCToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := dbtime.Now()
|
||||
cases := []struct {
|
||||
name string
|
||||
link database.UserLink
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "NoRefreshToken",
|
||||
link: database.UserLink{OAuthExpiry: now.Add(-time.Hour)},
|
||||
},
|
||||
{
|
||||
name: "ZeroExpiry",
|
||||
link: database.UserLink{OAuthRefreshToken: "refresh"},
|
||||
},
|
||||
{
|
||||
name: "Expired",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(-time.Hour),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Fresh",
|
||||
link: database.UserLink{
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: now.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, _ := shouldRefreshOIDCToken(tc.link)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCMCPTokenSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
t.Run("NilConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
require.Nil(t, newOIDCMCPTokenSource(db, nil, logger))
|
||||
})
|
||||
|
||||
t.Run("NoLink", func(t *testing.T) {
|
||||
// When the user has no OIDC link the source returns ("", nil)
|
||||
// rather than an error so the caller can fall through to
|
||||
// "no Authorization header".
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
store := dbauthzTestStore(t, db)
|
||||
user := dbgen.User(t, db, database.User{LoginType: database.LoginTypeOIDC})
|
||||
|
||||
src := newOIDCMCPTokenSource(store, &testutil.OAuth2Config{}, logger)
|
||||
ctx := dbauthz.AsChatd(context.Background())
|
||||
|
||||
tok, err := src.OIDCAccessToken(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tok)
|
||||
})
|
||||
|
||||
t.Run("FreshToken", func(t *testing.T) {
|
||||
// A non-expired token is returned as-is; no refresh is performed.
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
store := dbauthzTestStore(t, db)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthAccessToken: "fresh",
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: dbtime.Now().Add(time.Hour),
|
||||
})
|
||||
|
||||
src := newOIDCMCPTokenSource(store, &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{AccessToken: "should-not-be-used"},
|
||||
}, logger)
|
||||
ctx := dbauthz.AsChatd(context.Background())
|
||||
|
||||
tok, err := src.OIDCAccessToken(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fresh", tok)
|
||||
})
|
||||
|
||||
t.Run("RefreshExpired", func(t *testing.T) {
|
||||
// An expired token triggers a refresh; the new token is
|
||||
// persisted via UpdateUserLink. This exercises the dbauthz
|
||||
// elevation: chatd lacks ResourceSystem.Read and
|
||||
// ResourceUser.UpdatePersonal so a non-elevated context
|
||||
// would fail both reads and writes.
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
store := dbauthzTestStore(t, db)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthAccessToken: "stale",
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: dbtime.Now().Add(-time.Hour),
|
||||
})
|
||||
|
||||
src := newOIDCMCPTokenSource(store, &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "fresh",
|
||||
RefreshToken: "new-refresh",
|
||||
Expiry: dbtime.Now().Add(time.Hour),
|
||||
},
|
||||
}, logger)
|
||||
ctx := dbauthz.AsChatd(context.Background())
|
||||
|
||||
tok, err := src.OIDCAccessToken(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fresh", tok)
|
||||
|
||||
// Verify the refresh was persisted via UpdateUserLink.
|
||||
got, err := db.GetUserLinkByUserIDLoginType(
|
||||
dbauthz.AsSystemRestricted(context.Background()),
|
||||
database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fresh", got.OAuthAccessToken)
|
||||
require.Equal(t, "new-refresh", got.OAuthRefreshToken)
|
||||
})
|
||||
|
||||
t.Run("RefreshFailureReturnsEmpty", func(t *testing.T) {
|
||||
// A refresh attempt that fails (e.g. invalid client config)
|
||||
// must not surface an error to the caller; per the
|
||||
// UserOIDCTokenSource contract this is treated as "no
|
||||
// Authorization header".
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
store := dbauthzTestStore(t, db)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthAccessToken: "stale",
|
||||
OAuthRefreshToken: "refresh",
|
||||
OAuthExpiry: dbtime.Now().Add(-time.Hour),
|
||||
})
|
||||
|
||||
// An empty oauth2.Config triggers a refresh failure
|
||||
// because it has no token endpoint to call.
|
||||
src := newOIDCMCPTokenSource(store, &oauth2.Config{}, logger)
|
||||
ctx := dbauthz.AsChatd(context.Background())
|
||||
|
||||
tok, err := src.OIDCAccessToken(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tok)
|
||||
})
|
||||
}
|
||||
+96
-5
@@ -316,18 +316,109 @@ func TestMCPServerConfigsAuthConnected(t *testing.T) {
|
||||
// Also create a non-oauth server. It should report
|
||||
// auth_connected=true because no auth is needed.
|
||||
_ = createMCPServerConfig(t, adminClient, "no-auth-server", true)
|
||||
|
||||
// And a user_oidc server. user_oidc never requires a per-user
|
||||
// connect step, so auth_connected is always true regardless of
|
||||
// whether the calling user has an OIDC link.
|
||||
_, err = adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "User OIDC Server",
|
||||
Slug: "user-oidc-server",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/oidc",
|
||||
AuthType: "user_oidc",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memberConfigs, err = memberClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberConfigs, 2)
|
||||
require.Len(t, memberConfigs, 3)
|
||||
for _, cfg := range memberConfigs {
|
||||
if cfg.AuthType == "none" {
|
||||
require.True(t, cfg.AuthConnected)
|
||||
} else {
|
||||
require.False(t, cfg.AuthConnected)
|
||||
switch cfg.AuthType {
|
||||
case "none", "user_oidc":
|
||||
require.True(t, cfg.AuthConnected, "%s should report auth_connected", cfg.AuthType)
|
||||
default:
|
||||
require.False(t, cfg.AuthConnected, "%s should not report auth_connected", cfg.AuthType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsUserOIDCClearsFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Start with an oauth2 config that has a client secret, then
|
||||
// switch the auth_type to user_oidc and verify all auth-specific
|
||||
// fields are cleared.
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Switch Server",
|
||||
Slug: "switch-server",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/v1",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "cid",
|
||||
OAuth2ClientSecret: "secret-value",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
OAuth2Scopes: "read write",
|
||||
Availability: "default_off",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, created.HasOAuth2Secret)
|
||||
require.Equal(t, "cid", created.OAuth2ClientID)
|
||||
|
||||
newAuth := "user_oidc"
|
||||
updated, err := client.UpdateMCPServerConfig(ctx, created.ID, codersdk.UpdateMCPServerConfigRequest{
|
||||
AuthType: &newAuth,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user_oidc", updated.AuthType)
|
||||
require.False(t, updated.HasOAuth2Secret, "oauth2 secret should be cleared")
|
||||
require.False(t, updated.HasAPIKey, "api key should remain unset")
|
||||
require.False(t, updated.HasCustomHeaders, "custom headers should remain unset")
|
||||
require.Empty(t, updated.OAuth2ClientID)
|
||||
require.Empty(t, updated.OAuth2AuthURL)
|
||||
require.Empty(t, updated.OAuth2TokenURL)
|
||||
require.Empty(t, updated.OAuth2Scopes)
|
||||
require.Empty(t, updated.APIKeyHeader)
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsUserOIDCDirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create with user_oidc and confirm validation accepts the value
|
||||
// while no auth-specific fields are persisted on the row.
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "User OIDC Direct",
|
||||
Slug: "user-oidc-direct",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/oidc-direct",
|
||||
AuthType: "user_oidc",
|
||||
Availability: "default_off",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user_oidc", created.AuthType)
|
||||
require.False(t, created.HasOAuth2Secret)
|
||||
require.False(t, created.HasAPIKey)
|
||||
require.False(t, created.HasCustomHeaders)
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsAvailability(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -162,6 +162,7 @@ type Server struct {
|
||||
pubsub pubsub.Pubsub
|
||||
webpushDispatcher webpush.Dispatcher
|
||||
providerAPIKeys chatprovider.ProviderAPIKeys
|
||||
oidcTokenSource mcpclient.UserOIDCTokenSource
|
||||
debugSvc *chatdebug.Service
|
||||
debugSvcFactory func() *chatdebug.Service
|
||||
debugSvcReady atomic.Bool
|
||||
@@ -3840,6 +3841,12 @@ type Config struct {
|
||||
UsageTracker *workspacestats.UsageTracker
|
||||
Clock quartz.Clock
|
||||
PrometheusRegistry prometheus.Registerer
|
||||
|
||||
// OIDCTokenSource resolves the calling user's OIDC access
|
||||
// token for MCP servers configured with auth_type=user_oidc.
|
||||
// May be nil if the deployment has no OIDC provider; servers
|
||||
// using user_oidc will then send no Authorization header.
|
||||
OIDCTokenSource mcpclient.UserOIDCTokenSource
|
||||
}
|
||||
|
||||
// New creates a new chat processor. The processor polls for pending
|
||||
@@ -3898,6 +3905,7 @@ func New(cfg Config) *Server {
|
||||
pubsub: cfg.Pubsub,
|
||||
webpushDispatcher: cfg.WebpushDispatcher,
|
||||
providerAPIKeys: cfg.ProviderAPIKeys,
|
||||
oidcTokenSource: cfg.OIDCTokenSource,
|
||||
debugSvcFactory: func() *chatdebug.Service {
|
||||
debugSvc := chatdebug.NewService(
|
||||
cfg.Database,
|
||||
@@ -6472,7 +6480,7 @@ func (p *Server) runChat(
|
||||
// Refresh expired OAuth2 tokens before connecting.
|
||||
mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens)
|
||||
mcpTools, mcpCleanup = mcpclient.ConnectAll(
|
||||
ctx, logger, mcpConnectConfigs, mcpTokens,
|
||||
ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource,
|
||||
)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -49,6 +49,18 @@ const connectTimeout = 10 * time.Second
|
||||
// take before being canceled.
|
||||
const toolCallTimeout = 60 * time.Second
|
||||
|
||||
// UserOIDCTokenSource resolves the OIDC access token for the calling
|
||||
// user. Implementations attempt to refresh tokens that are expired
|
||||
// or close to expiring and MUST return ("", nil) when the user has
|
||||
// no OIDC link or a refresh attempt failed for any reason. A
|
||||
// non-nil error is reserved for unexpected infrastructure failures
|
||||
// (e.g. database errors) and skips header construction entirely.
|
||||
// The empty-token-on-refresh-failure behavior matches
|
||||
// provisionerdserver.ObtainOIDCAccessToken.
|
||||
type UserOIDCTokenSource interface {
|
||||
OIDCAccessToken(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
}
|
||||
|
||||
// ConnectAll connects to all configured MCP servers, discovers
|
||||
// their tools, and returns them as fantasy.AgentTool values.
|
||||
// Tools are sorted by their prefixed name so callers
|
||||
@@ -60,6 +72,8 @@ func ConnectAll(
|
||||
logger slog.Logger,
|
||||
configs []database.MCPServerConfig,
|
||||
tokens []database.MCPServerUserToken,
|
||||
userID uuid.UUID,
|
||||
oidcSrc UserOIDCTokenSource,
|
||||
) ([]fantasy.AgentTool, func()) {
|
||||
// Index tokens by server config ID so auth header
|
||||
// construction is O(1) per server.
|
||||
@@ -95,7 +109,7 @@ func ConnectAll(
|
||||
|
||||
eg.Go(func() error {
|
||||
serverTools, mcpClient, connectErr := connectOne(
|
||||
ctx, logger, cfg, tokensByConfigID,
|
||||
ctx, logger, cfg, tokensByConfigID, userID, oidcSrc,
|
||||
)
|
||||
if connectErr != nil {
|
||||
logger.Warn(ctx,
|
||||
@@ -159,8 +173,10 @@ func connectOne(
|
||||
logger slog.Logger,
|
||||
cfg database.MCPServerConfig,
|
||||
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
||||
userID uuid.UUID,
|
||||
oidcSrc UserOIDCTokenSource,
|
||||
) ([]fantasy.AgentTool, *client.Client, error) {
|
||||
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID)
|
||||
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID, userID, oidcSrc)
|
||||
|
||||
tr, err := createTransport(cfg, headers)
|
||||
if err != nil {
|
||||
@@ -285,6 +301,8 @@ func buildAuthHeaders(
|
||||
logger slog.Logger,
|
||||
cfg database.MCPServerConfig,
|
||||
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
|
||||
userID uuid.UUID,
|
||||
oidcSrc UserOIDCTokenSource,
|
||||
) map[string]string {
|
||||
// Using map[string]string rather than http.Header because
|
||||
// the mcp-go transport options accept map[string]string.
|
||||
@@ -347,6 +365,40 @@ func buildAuthHeaders(
|
||||
}
|
||||
}
|
||||
}
|
||||
case "user_oidc":
|
||||
// Forward the calling user's OIDC access token from
|
||||
// user_links as Authorization: Bearer <token>. The token
|
||||
// source is responsible for refreshing tokens that are
|
||||
// expired or close to expiring before returning them.
|
||||
if oidcSrc == nil || userID == uuid.Nil {
|
||||
logger.Warn(ctx,
|
||||
"user_oidc auth requested but no token source available",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
)
|
||||
break
|
||||
}
|
||||
token, err := oidcSrc.OIDCAccessToken(ctx, userID)
|
||||
if err != nil {
|
||||
logger.Warn(ctx,
|
||||
"failed to obtain user OIDC token for MCP server",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
slog.Error(err),
|
||||
)
|
||||
break
|
||||
}
|
||||
if token == "" {
|
||||
// The user has no OIDC link, or a non-fatal refresh
|
||||
// failure occurred. Fall through with no header and let
|
||||
// the upstream MCP server decide how to respond
|
||||
// (typically 401). Logged at debug so password and
|
||||
// GitHub users don't generate noise for every chat turn.
|
||||
logger.Debug(ctx,
|
||||
"no user OIDC token available for MCP server",
|
||||
slog.F("server_slug", cfg.Slug),
|
||||
)
|
||||
break
|
||||
}
|
||||
headers["Authorization"] = "Bearer " + token
|
||||
case "none", "":
|
||||
// No auth headers needed.
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestConnectAll_DiscoverTools(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool(), greetTool())
|
||||
|
||||
cfg := makeConfig("myserver", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Two tools should be discovered, namespaced with the server slug.
|
||||
@@ -121,7 +121,7 @@ func TestConnectAll_CallTool(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -147,7 +147,7 @@ func TestConnectAll_ToolAllowList(t *testing.T) {
|
||||
// Only allow the "echo" tool.
|
||||
cfg.ToolAllowList = []string{"echo"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
@@ -165,7 +165,7 @@ func TestConnectAll_ToolDenyList(t *testing.T) {
|
||||
// Deny the "greet" tool, so only "echo" remains.
|
||||
cfg.ToolDenyList = []string{"greet"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
@@ -179,7 +179,7 @@ func TestConnectAll_ConnectionFailure(t *testing.T) {
|
||||
|
||||
cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist")
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
assert.Empty(t, tools, "no tools should be returned for an unreachable server")
|
||||
@@ -200,6 +200,7 @@ func TestConnectAll_MultipleServers(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -225,6 +226,7 @@ func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
|
||||
require.Empty(t, tools)
|
||||
@@ -252,6 +254,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
makeConfig("srv2", ts2.URL),
|
||||
},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -280,6 +283,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
makeConfig("aaa", other.URL),
|
||||
},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -312,6 +316,7 @@ func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -375,6 +380,7 @@ func TestConnectAll_AuthHeaders(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
[]database.MCPServerUserToken{token},
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -429,7 +435,7 @@ func TestConnectAll_DisabledServer(t *testing.T) {
|
||||
cfg := makeConfig("disabled", ts.URL)
|
||||
cfg.Enabled = false
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
assert.Empty(t, tools)
|
||||
}
|
||||
@@ -444,7 +450,7 @@ func TestConnectAll_CallToolInvalidInput(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -469,7 +475,7 @@ func TestConnectAll_ToolInfoParameters(t *testing.T) {
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -511,7 +517,7 @@ func TestConnectAll_NilRequiredBecomesEmptySlice(t *testing.T) {
|
||||
|
||||
ts := newTestMCPServer(t, noRequiredTool)
|
||||
cfg := makeConfig("srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -563,6 +569,7 @@ func TestConnectAll_APIKeyAuth(t *testing.T) {
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -619,6 +626,7 @@ func TestConnectAll_CustomHeadersAuth(t *testing.T) {
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -655,6 +663,7 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -664,6 +673,158 @@ func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
|
||||
assert.Equal(t, "badjson__echo", tools[0].Info().Name)
|
||||
}
|
||||
|
||||
// staticOIDCSource implements mcpclient.UserOIDCTokenSource for tests
|
||||
// without requiring a real OIDC provider or database round-trip.
|
||||
type staticOIDCSource struct {
|
||||
token string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s staticOIDCSource) OIDCAccessToken(_ context.Context, _ uuid.UUID) (string, error) {
|
||||
return s.token, s.err
|
||||
}
|
||||
|
||||
// TestConnectAll_UserOIDCAuth verifies that the user_oidc auth type
|
||||
// forwards the calling user's OIDC access token from the
|
||||
// UserOIDCTokenSource as Authorization: Bearer <token>.
|
||||
func TestConnectAll_UserOIDCAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seenHeaders []string
|
||||
)
|
||||
|
||||
srv := mcpserver.NewMCPServer("oidc-server", "1.0.0")
|
||||
srv.AddTools(mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool("whoami",
|
||||
mcp.WithDescription("Returns the auth header"),
|
||||
),
|
||||
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
auth := req.Header.Get("Authorization")
|
||||
mu.Lock()
|
||||
seenHeaders = append(seenHeaders, auth)
|
||||
mu.Unlock()
|
||||
return mcp.NewToolResultText("auth:" + auth), nil
|
||||
},
|
||||
})
|
||||
|
||||
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
|
||||
ts := httptest.NewServer(httpSrv)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("oidc-srv", ts.URL)
|
||||
cfg.AuthType = "user_oidc"
|
||||
userID := uuid.New()
|
||||
src := staticOIDCSource{token: "fake-oidc-token"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
userID, src,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-oidc",
|
||||
Name: "oidc-srv__whoami",
|
||||
Input: "{}",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
assert.Equal(t, "auth:Bearer fake-oidc-token", resp.Content)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.NotEmpty(t, seenHeaders)
|
||||
assert.Equal(t, "Bearer fake-oidc-token", seenHeaders[len(seenHeaders)-1])
|
||||
}
|
||||
|
||||
// TestConnectAll_UserOIDCAuth_NoLink verifies that when the token
|
||||
// source returns ("", nil) (the user has no OIDC link), the request
|
||||
// is still made but with no Authorization header. The MCP server is
|
||||
// then free to respond with 401 or proceed unauthenticated.
|
||||
func TestConnectAll_UserOIDCAuth_NoLink(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seenHeaders []string
|
||||
)
|
||||
|
||||
srv := mcpserver.NewMCPServer("oidc-server-nolink", "1.0.0")
|
||||
srv.AddTools(mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool("whoami",
|
||||
mcp.WithDescription("Returns the auth header"),
|
||||
),
|
||||
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
auth := req.Header.Get("Authorization")
|
||||
mu.Lock()
|
||||
seenHeaders = append(seenHeaders, auth)
|
||||
mu.Unlock()
|
||||
return mcp.NewToolResultText("auth:" + auth), nil
|
||||
},
|
||||
})
|
||||
|
||||
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
|
||||
ts := httptest.NewServer(httpSrv)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("oidc-nolink", ts.URL)
|
||||
cfg.AuthType = "user_oidc"
|
||||
src := staticOIDCSource{token: "", err: nil}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
uuid.New(), src,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-oidc-nolink",
|
||||
Name: "oidc-nolink__whoami",
|
||||
Input: "{}",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
assert.Equal(t, "auth:", resp.Content)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.NotEmpty(t, seenHeaders)
|
||||
assert.Empty(t, seenHeaders[len(seenHeaders)-1])
|
||||
}
|
||||
|
||||
// TestConnectAll_UserOIDCAuth_NilSource verifies that a nil token
|
||||
// source (e.g. deployment with no OIDC provider) yields no
|
||||
// Authorization header rather than panicking.
|
||||
func TestConnectAll_UserOIDCAuth_NilSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("oidc-nilsrc", ts.URL)
|
||||
cfg.AuthType = "user_oidc"
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx, logger, []database.MCPServerConfig{cfg}, nil,
|
||||
uuid.New(), nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
assert.Equal(t, "oidc-nilsrc__echo", tools[0].Info().Name)
|
||||
}
|
||||
|
||||
// TestConnectAll_ParallelConnections verifies that connecting to
|
||||
// multiple MCP servers simultaneously returns all discovered
|
||||
// tools with the correct server slug prefixes.
|
||||
@@ -684,6 +845,7 @@ func TestConnectAll_ParallelConnections(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2, cfg3},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -744,7 +906,7 @@ func TestConnectAll_ExpiredToken(t *testing.T) {
|
||||
Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token})
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// The server accepts any auth, so the tool is still discovered
|
||||
@@ -777,7 +939,7 @@ func TestConnectAll_EmptyAccessToken(t *testing.T) {
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token})
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token}, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Tool is still discovered (server doesn't require auth), but
|
||||
@@ -807,7 +969,7 @@ func TestConnectAll_MCPToolIdentifier(t *testing.T) {
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 1)
|
||||
@@ -853,6 +1015,7 @@ func TestConnectAll_MCPToolIdentifier_MultipleServers(t *testing.T) {
|
||||
ctx, logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
uuid.Nil, nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -909,7 +1072,7 @@ func TestConnectAll_EmbeddedResourceText(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("embed-txt", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -976,7 +1139,7 @@ func TestConnectAll_EmbeddedResourceBlob(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("embed-blob", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1056,7 +1219,7 @@ func TestConnectAll_ResourceLink(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("res-link", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1100,7 +1263,7 @@ func TestConnectAll_CallToolError(t *testing.T) {
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
cfg := makeConfig("err-srv", ts.URL)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1124,7 +1287,7 @@ func TestModelIntent_Info_WrapsSchema(t *testing.T) {
|
||||
cfg := makeConfig("intent-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1160,7 +1323,7 @@ func TestModelIntent_Info_NoWrapWhenDisabled(t *testing.T) {
|
||||
cfg := makeConfig("no-intent", ts.URL)
|
||||
cfg.ModelIntent = false
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1183,7 +1346,7 @@ func TestModelIntent_Run_UnwrapsProperties(t *testing.T) {
|
||||
cfg := makeConfig("unwrap-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1208,7 +1371,7 @@ func TestModelIntent_Run_UnwrapsFlat(t *testing.T) {
|
||||
cfg := makeConfig("flat-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1233,7 +1396,7 @@ func TestModelIntent_Run_PassthroughWhenDisabled(t *testing.T) {
|
||||
cfg := makeConfig("pass-srv", ts.URL)
|
||||
cfg.ModelIntent = false
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
@@ -1258,7 +1421,7 @@ func TestModelIntent_Run_FallbackOnBadJSON(t *testing.T) {
|
||||
cfg := makeConfig("bad-srv", ts.URL)
|
||||
cfg.ModelIntent = true
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
|
||||
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil, uuid.Nil, nil)
|
||||
t.Cleanup(cleanup)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
|
||||
+3
-3
@@ -42,7 +42,7 @@ type MCPServerConfig struct {
|
||||
Transport string `json:"transport"` // "streamable_http" or "sse"
|
||||
URL string `json:"url"`
|
||||
|
||||
AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers"
|
||||
AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers", "user_oidc"
|
||||
|
||||
// OAuth2 fields (only populated for admins).
|
||||
OAuth2ClientID string `json:"oauth2_client_id,omitempty"`
|
||||
@@ -84,7 +84,7 @@ type CreateMCPServerConfigRequest struct {
|
||||
Transport string `json:"transport" validate:"required,oneof=streamable_http sse"`
|
||||
URL string `json:"url" validate:"required,url"`
|
||||
|
||||
AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers"`
|
||||
AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers user_oidc"`
|
||||
OAuth2ClientID string `json:"oauth2_client_id,omitempty"`
|
||||
OAuth2ClientSecret string `json:"oauth2_client_secret,omitempty"`
|
||||
OAuth2AuthURL string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"`
|
||||
@@ -113,7 +113,7 @@ type UpdateMCPServerConfigRequest struct {
|
||||
Transport *string `json:"transport,omitempty" validate:"omitempty,oneof=streamable_http sse"`
|
||||
URL *string `json:"url,omitempty" validate:"omitempty,url"`
|
||||
|
||||
AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers"`
|
||||
AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers user_oidc"`
|
||||
OAuth2ClientID *string `json:"oauth2_client_id,omitempty"`
|
||||
OAuth2ClientSecret *string `json:"oauth2_client_secret,omitempty"`
|
||||
OAuth2AuthURL *string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"`
|
||||
|
||||
@@ -48,7 +48,7 @@ This is an admin-only feature accessible at **Agents** > **Settings** >
|
||||
|
||||
## Authentication
|
||||
|
||||
Each MCP server uses one of four authentication modes. When you change the
|
||||
Each MCP server uses one of five authentication modes. When you change the
|
||||
auth type, fields from the previous type are automatically cleared.
|
||||
|
||||
Secrets are never returned in API responses — boolean flags indicate whether
|
||||
@@ -104,6 +104,21 @@ A static key sent as a header on every request.
|
||||
Arbitrary key-value header pairs sent on every request. At least one header
|
||||
is required when this mode is selected.
|
||||
|
||||
### User OIDC Identity
|
||||
|
||||
Forwards the calling user's OIDC access token (stored in
|
||||
`user_links.oauth_access_token`) to the MCP server as an
|
||||
`Authorization: Bearer <token>` header. The token is refreshed
|
||||
transparently before each request if it has expired or is close to
|
||||
expiring.
|
||||
|
||||
No admin-configurable fields. No per-user connect step.
|
||||
|
||||
**Limitation**: this auth mode only works for users who authenticated to
|
||||
Coder via OIDC. Users who logged in with password or GitHub will see
|
||||
requests sent without an authorization header, and the upstream MCP
|
||||
server is expected to respond with 401.
|
||||
|
||||
## Tool governance
|
||||
|
||||
Control which tools from a server are available in chat:
|
||||
|
||||
Generated
+1
-1
@@ -4579,7 +4579,7 @@ export interface MCPServerConfig {
|
||||
readonly icon_url: string;
|
||||
readonly transport: string; // "streamable_http" or "sse"
|
||||
readonly url: string;
|
||||
readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers"
|
||||
readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers", "user_oidc"
|
||||
/**
|
||||
* OAuth2 fields (only populated for admins).
|
||||
*/
|
||||
|
||||
@@ -698,3 +698,55 @@ export const CustomHeadersAuthType: Story = {
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
export const CreateServerUserOIDC: Story = {
|
||||
args: {
|
||||
serversData: [],
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: /Add your first server/i }),
|
||||
);
|
||||
|
||||
await userEvent.type(
|
||||
await body.findByLabelText(/Display Name/i),
|
||||
"Internal API",
|
||||
);
|
||||
await userEvent.type(
|
||||
body.getByLabelText(/Server URL/i),
|
||||
"https://mcp.internal.example.com/v1",
|
||||
);
|
||||
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: /Authentication/i }),
|
||||
);
|
||||
await userEvent.click(body.getByLabelText(/Authentication/i));
|
||||
await userEvent.click(
|
||||
await body.findByRole("option", { name: /User OIDC Identity/i }),
|
||||
);
|
||||
|
||||
// No additional auth fields for user_oidc; the helper text is shown.
|
||||
expect(
|
||||
body.getByText(/forwarded to this MCP server in the/i),
|
||||
).toBeInTheDocument();
|
||||
|
||||
await userEvent.click(body.getByRole("button", { name: /Create server/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(args.onCreateServer).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
expect(args.onCreateServer).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
auth_type: "user_oidc",
|
||||
}),
|
||||
);
|
||||
// Should not include any oauth2/api_key/custom_headers fields.
|
||||
const call = (args.onCreateServer as ReturnType<typeof fn>).mock
|
||||
.calls[0][0];
|
||||
expect(call).not.toHaveProperty("oauth2_client_id");
|
||||
expect(call).not.toHaveProperty("api_key_value");
|
||||
expect(call).not.toHaveProperty("custom_headers");
|
||||
},
|
||||
};
|
||||
|
||||
@@ -69,6 +69,7 @@ const AUTH_TYPE_OPTIONS = [
|
||||
{ value: "oauth2", label: "OAuth2" },
|
||||
{ value: "api_key", label: "API Key" },
|
||||
{ value: "custom_headers", label: "Custom Headers" },
|
||||
{ value: "user_oidc", label: "User OIDC Identity" },
|
||||
] as const;
|
||||
|
||||
const AVAILABILITY_OPTIONS = [
|
||||
@@ -900,6 +901,22 @@ const ServerForm: FC<ServerFormProps> = ({
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{form.values.authType === "user_oidc" && (
|
||||
<div className="space-y-2 rounded-lg border border-solid border-border/70 bg-surface-secondary/30 p-4 text-xs text-content-secondary">
|
||||
<p className="m-0">
|
||||
The calling user's OIDC access token is forwarded to this
|
||||
MCP server in the <code>Authorization</code> header.
|
||||
Tokens are refreshed transparently before each request.
|
||||
</p>
|
||||
<p className="m-0">
|
||||
Users who did not log in via OIDC (for example, password
|
||||
or GitHub login) will see requests sent without an
|
||||
authorization header. Configure no other fields for this
|
||||
auth type.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</CollapsibleContent>
|
||||
</div>
|
||||
</Collapsible>
|
||||
|
||||
Reference in New Issue
Block a user