feat: add MCP server configuration backend for chats (#23227)

## Summary

Adds the database schema, API endpoints, SDK types, and encryption
wrappers for admin-managed MCP (Model Context Protocol) server
configurations that chatd can consume. This is the backend foundation
for allowing external MCP tools (Sentry, Linear, GitHub, etc.) to be
used during AI chat sessions.

## Database

Two new tables:
- **`mcp_server_configs`**: Admin-managed server definitions with URL,
transport (Streamable HTTP / SSE), auth config (none / OAuth2 / API key
/ custom headers), tool allow/deny lists, and an availability policy
(`force_on` / `default_on` / `default_off`). Includes CHECK constraints
on transport, auth_type, and availability values.
- **`mcp_server_user_tokens`**: Per-user OAuth2 tokens for servers
requiring individual authentication. Cascades on user/config deletion.

New column on `chats` table:
- **`mcp_server_ids UUID[]`**: Per-chat MCP server selection, following
the same pattern as `model_config_id` — passed at chat creation,
changeable per-message with nil-means-no-change semantics.

## API Endpoints

All routes are under `/api/experimental/mcp/servers/` and gated behind
the `agents` experiment.

**Admin endpoints** (`ResourceDeploymentConfig` auth):
- `POST /` — Create MCP server config
- `PATCH /{id}` — Update MCP server config (full-replace)
- `DELETE /{id}` — Delete MCP server config

**Authenticated endpoints** (all users, enabled servers only for
non-admins):
- `GET /` — List configs (admins see all, members see enabled-only with
admin fields redacted)
- `GET /{id}` — Get config by ID (with `auth_connected` populated
per-user)

**OAuth2 per-user auth flow:**
- `GET /{id}/oauth2/connect` — Initiate OAuth2 flow (state cookie CSRF
protection)
- `GET /{id}/oauth2/callback` — Handle OAuth2 callback, store tokens
- `DELETE /{id}/oauth2/disconnect` — Remove stored OAuth2 tokens

## Security

- **Secrets never returned**: `OAuth2ClientSecret`, `APIKeyValue`, and
`CustomHeaders` are never in API responses — only boolean indicators
(`has_oauth2_secret`, `has_api_key`, `has_custom_headers`).
- **Field redaction for non-admins**: `convertMCPServerConfigRedacted`
strips `OAuth2ClientID`, auth URLs, scopes, and `APIKeyHeader` from
non-admin responses.
- **dbcrypt encryption at rest**: All 5 secret fields use `dbcrypt_keys`
encryption with full encrypt-on-write / decrypt-on-read wrappers (11
dbcrypt method overrides + 2 helpers), following the same pattern as
`chat_providers.api_key`.
- **OAuth2 CSRF protection**: State parameter stored in `HttpOnly`
cookie with `HTTPCookies.Apply()` for correct `Secure`/`SameSite` behind
TLS-terminating proxies.
- **dbauthz authorization**: All 18 querier methods have authorization
wrappers. Read operations use `ActionRead`, write operations use
`ActionUpdate` on `ResourceDeploymentConfig`.

## Governance Model

| Control | Implementation |
|---------|---------------|
| **Global kill switch** | `enabled` defaults to `false` |
| **Availability policy** | `force_on` (always injected), `default_on`
(pre-selected), `default_off` (opt-in) |
| **Per-chat selection** | `mcp_server_ids` on `CreateChatRequest` /
`CreateChatMessageRequest` |
| **Auth gate** | OAuth2 servers require per-user auth before tools are
injected |
| **Tool-level allow/deny** | Arrays on `mcp_server_configs` for
granular tool filtering |
| **Secrets encrypted at rest** | Uses `dbcrypt_keys` (same pattern as
`chat_providers.api_key`) |

## Tests

8 test functions covering:
- Full CRUD lifecycle (create, list, update, delete)
- Non-admin visibility filtering (enabled-only, field redaction)
- `auth_connected` population for OAuth2 vs non-OAuth2 servers
- Availability policy validation (valid values + invalid rejection)
- Unique slug enforcement (409 Conflict)
- OAuth2 disconnect idempotency
- Chat creation with `mcp_server_ids` persistence

## Known Limitations (Deferred)

These are documented and intentional for an experimental feature:
- **Audit logging** not yet wired — will add when feature stabilizes
- **Cross-field validation** (e.g., OAuth2 fields required when
`auth_type=oauth2`) — admin-only endpoint, will add when stabilizing
- **`force_on` auto-injection** — query exists but not yet wired into
chatd tool injection (follow-up)
- **Additional test coverage** — 403 auth tests, GET-by-ID tests,
callback CSRF tests planned for follow-up

## What's NOT in this PR

- Frontend UI (admin panel + chat picker)
- Actual MCP client connections (`chatd/chatmcp/` manager)
- Tool injection into `chatloop/`
This commit is contained in:
Kyle Carberry
2026-03-19 10:07:36 -04:00
committed by GitHub
parent 8f78c5145f
commit d8ff67fb68
39 changed files with 4300 additions and 30 deletions
+20
View File
@@ -379,6 +379,7 @@ type CreateOptions struct {
ChatMode database.NullChatMode
SystemPrompt string
InitialUserContent []codersdk.ChatMessagePart
MCPServerIDs []uuid.UUID
}
// SendMessageBusyBehavior controls what happens when a chat is already active.
@@ -401,6 +402,7 @@ type SendMessageOptions struct {
Content []codersdk.ChatMessagePart
ModelConfigID *uuid.UUID
BusyBehavior SendMessageBusyBehavior
MCPServerIDs *[]uuid.UUID
}
// SendMessageResult contains the outcome of user message processing.
@@ -450,6 +452,12 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
if len(opts.InitialUserContent) == 0 {
return database.Chat{}, xerrors.New("initial user content is required")
}
// Ensure MCPServerIDs is non-nil so pq.Array produces '{}'
// instead of SQL NULL, which violates the NOT NULL column
// constraint.
if opts.MCPServerIDs == nil {
opts.MCPServerIDs = []uuid.UUID{}
}
var chat database.Chat
txErr := p.db.InTx(func(tx database.Store) error {
@@ -465,6 +473,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
LastModelConfigID: opts.ModelConfigID,
Title: opts.Title,
Mode: opts.ChatMode,
MCPServerIDs: opts.MCPServerIDs,
})
if err != nil {
return xerrors.Errorf("insert chat: %w", err)
@@ -596,6 +605,17 @@ func (p *Server) SendMessage(
modelConfigID = *opts.ModelConfigID
}
// Update MCP server IDs on the chat when explicitly provided.
if opts.MCPServerIDs != nil {
lockedChat, err = tx.UpdateChatMCPServerIDs(ctx, database.UpdateChatMCPServerIDsParams{
ID: opts.ChatID,
MCPServerIDs: *opts.MCPServerIDs,
})
if err != nil {
return xerrors.Errorf("update chat mcp server ids: %w", err)
}
}
existingQueued, err := tx.GetChatQueuedMessages(ctx, opts.ChatID)
if err != nil {
return xerrors.Errorf("get queued messages: %w", err)
+72
View File
@@ -284,6 +284,41 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
// Validate MCP server IDs exist.
if len(req.MCPServerIDs) > 0 {
//nolint:gocritic // Need to validate MCP server IDs exist.
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), req.MCPServerIDs)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to validate MCP server IDs.",
Detail: err.Error(),
})
return
}
if len(existingConfigs) != len(req.MCPServerIDs) {
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
for _, c := range existingConfigs {
found[c.ID] = struct{}{}
}
var missing []string
for _, id := range req.MCPServerIDs {
if _, ok := found[id]; !ok {
missing = append(missing, id.String())
}
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "One or more MCP server IDs are invalid.",
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
})
return
}
}
mcpServerIDs := req.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
}
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
OwnerID: apiKey.UserID,
WorkspaceID: workspaceSelection.WorkspaceID,
@@ -291,6 +326,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
ModelConfigID: modelConfigID,
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
InitialUserContent: contentBlocks,
MCPServerIDs: mcpServerIDs,
})
if err != nil {
if maybeWriteLimitErr(ctx, rw, err) {
@@ -1456,6 +1492,36 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
return
}
// Validate MCP server IDs exist.
if req.MCPServerIDs != nil && len(*req.MCPServerIDs) > 0 {
//nolint:gocritic // Need to validate MCP server IDs exist.
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), *req.MCPServerIDs)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to validate MCP server IDs.",
Detail: err.Error(),
})
return
}
if len(existingConfigs) != len(*req.MCPServerIDs) {
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
for _, c := range existingConfigs {
found[c.ID] = struct{}{}
}
var missing []string
for _, id := range *req.MCPServerIDs {
if _, ok := found[id]; !ok {
missing = append(missing, id.String())
}
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "One or more MCP server IDs are invalid.",
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
})
return
}
}
sendResult, sendErr := api.chatDaemon.SendMessage(
ctx,
chatd.SendMessageOptions{
@@ -1464,6 +1530,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
Content: contentBlocks,
ModelConfigID: req.ModelConfigID,
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
MCPServerIDs: req.MCPServerIDs,
},
)
if sendErr != nil {
@@ -2979,6 +3046,10 @@ func truncateRunes(value string, maxLen int) string {
}
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
mcpServerIDs := c.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
}
chat := codersdk.Chat{
ID: c.ID,
OwnerID: c.OwnerID,
@@ -2988,6 +3059,7 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
Archived: c.Archived,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
MCPServerIDs: mcpServerIDs,
}
if c.LastError.Valid {
chat.LastError = &c.LastError.String
+19 -2
View File
@@ -1233,10 +1233,27 @@ func New(options *Options) *API {
r.Route("/mcp", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
)
// MCP server configuration endpoints.
r.Route("/servers", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents))
r.Get("/", api.listMCPServerConfigs)
r.Post("/", api.createMCPServerConfig)
r.Route("/{mcpServer}", func(r chi.Router) {
r.Get("/", api.getMCPServerConfig)
r.Patch("/", api.updateMCPServerConfig)
r.Delete("/", api.deleteMCPServerConfig)
// OAuth2 user flow
r.Get("/oauth2/connect", api.mcpServerOAuth2Connect)
r.Get("/oauth2/callback", api.mcpServerOAuth2Callback)
r.Delete("/oauth2/disconnect", api.mcpServerOAuth2Disconnect)
})
})
// MCP HTTP transport endpoint with mandatory authentication
r.Mount("/http", api.mcpHTTPHandler())
r.Route("/http", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP))
r.Mount("/", api.mcpHTTPHandler())
})
})
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
r.Use(
+3
View File
@@ -20,6 +20,9 @@ const (
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs
CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs
CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
+109
View File
@@ -1691,6 +1691,13 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
return q.db.CleanTailnetTunnels(ctx)
}
func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return err
}
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
}
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -1920,6 +1927,20 @@ func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
return id, nil
}
func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteMCPServerConfigByID(ctx, id)
}
func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteMCPServerUserToken(ctx, arg)
}
func (q *querier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceOauth2App); err != nil {
return err
@@ -2763,6 +2784,13 @@ func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatP
return q.db.GetEnabledChatProviders(ctx)
}
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledMCPServerConfigs(ctx)
}
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
}
@@ -2821,6 +2849,13 @@ func (q *querier) GetFilteredInboxNotificationsByUserID(ctx context.Context, arg
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetFilteredInboxNotificationsByUserID)(ctx, arg)
}
func (q *querier) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetForcedMCPServerConfigs(ctx)
}
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID)
}
@@ -2961,6 +2996,48 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
return q.db.GetLogoURL(ctx)
}
func (q *querier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.GetMCPServerConfigByID(ctx, id)
}
func (q *querier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.GetMCPServerConfigBySlug(ctx, slug)
}
func (q *querier) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerConfigs(ctx)
}
func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerConfigsByIDs(ctx, ids)
}
func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
}
return q.db.GetMCPServerUserToken(ctx, arg)
}
func (q *querier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerUserTokensByUserID(ctx, userID)
}
func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
return nil, err
@@ -4727,6 +4804,13 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
return q.db.InsertLicense(ctx, arg)
}
func (q *querier) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.InsertMCPServerConfig(ctx, arg)
}
func (q *querier) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
return database.WorkspaceAgentMemoryResourceMonitor{}, err
@@ -5486,6 +5570,17 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatMCPServerIDs(ctx, arg)
}
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
// Authorize update on the parent chat of the edited message.
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
@@ -5650,6 +5745,13 @@ func (q *querier) UpdateInboxNotificationReadStatus(ctx context.Context, args da
return update(q.log, q.auth, fetchFunc, q.db.UpdateInboxNotificationReadStatus)(ctx, args)
}
func (q *querier) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.UpdateMCPServerConfig(ctx, arg)
}
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
@@ -6695,6 +6797,13 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
return q.db.UpsertLogoURL(ctx, value)
}
func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
}
return q.db.UpsertMCPServerUserToken(ctx, arg)
}
func (q *querier) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return err
+110 -2
View File
@@ -1005,6 +1005,114 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("CleanupDeletedMCPServerIDsFromChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CleanupDeletedMCPServerIDsFromChats(gomock.Any()).Return(nil).AnyTimes()
check.Args().Asserts(rbac.ResourceChat, policy.ActionUpdate)
}))
s.Run("DeleteMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteMCPServerConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.DeleteMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
dbm.EXPECT().DeleteMCPServerUserToken(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetEnabledMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetEnabledMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetForcedMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetForcedMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetMCPServerConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetMCPServerConfigBySlug", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
slug := "test-mcp-server"
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{Slug: slug})
dbm.EXPECT().GetMCPServerConfigBySlug(gomock.Any(), slug).Return(config, nil).AnyTimes()
check.Args(slug).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerConfigsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
ids := []uuid.UUID{configA.ID, configB.ID}
dbm.EXPECT().GetMCPServerConfigsByIDs(gomock.Any(), ids).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().GetMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(token)
}))
s.Run("GetMCPServerUserTokensByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
userID := uuid.New()
tokens := []database.MCPServerUserToken{testutil.Fake(s.T(), faker, database.MCPServerUserToken{UserID: userID})}
dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens)
}))
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertMCPServerConfigParams{
DisplayName: "Test MCP Server",
Slug: "test-mcp-server",
}
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{DisplayName: arg.DisplayName, Slug: arg.Slug})
dbm.EXPECT().InsertMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpdateChatMCPServerIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatMCPServerIDsParams{
ID: chat.ID,
MCPServerIDs: []uuid.UUID{uuid.New()},
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
arg := database.UpdateMCPServerConfigParams{
ID: config.ID,
DisplayName: "Updated MCP Server",
Slug: "updated-mcp-server",
}
dbm.EXPECT().UpdateMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpsertMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
AccessToken: "test-access-token",
TokenType: "bearer",
}
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().UpsertMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(token)
}))
}
func (s *MethodTestSuite) TestFile() {
@@ -1381,8 +1489,8 @@ func (s *MethodTestSuite) TestLicense() {
check.Args().Asserts().Returns("value")
}))
s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes()
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"})
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}, nil).AnyTimes()
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"})
}))
s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes()
+120
View File
@@ -264,6 +264,14 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
start := time.Now()
r0 := m.s.CleanupDeletedMCPServerIDsFromChats(ctx)
m.queryLatencies.WithLabelValues("CleanupDeletedMCPServerIDsFromChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CleanupDeletedMCPServerIDsFromChats").Inc()
return r0
}
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
@@ -480,6 +488,22 @@ func (m queryMetricsStore) DeleteLicense(ctx context.Context, id int32) (int32,
return r0, r1
}
func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteMCPServerConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteMCPServerConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerConfigByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
start := time.Now()
r0 := m.s.DeleteMCPServerUserToken(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteMCPServerUserToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserToken").Inc()
return r0
}
func (m queryMetricsStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteOAuth2ProviderAppByClientID(ctx, id)
@@ -1328,6 +1352,14 @@ func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]datab
return r0, r1
}
func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx)
m.queryLatencies.WithLabelValues("GetEnabledMCPServerConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledMCPServerConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
start := time.Now()
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
@@ -1384,6 +1416,14 @@ func (m queryMetricsStore) GetFilteredInboxNotificationsByUserID(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetForcedMCPServerConfigs(ctx)
m.queryLatencies.WithLabelValues("GetForcedMCPServerConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetForcedMCPServerConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
start := time.Now()
r0, r1 := m.s.GetGitSSHKey(ctx, userID)
@@ -1544,6 +1584,54 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) {
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("GetMCPServerConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigBySlug(ctx, slug)
m.queryLatencies.WithLabelValues("GetMCPServerConfigBySlug").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigBySlug").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigs(ctx)
m.queryLatencies.WithLabelValues("GetMCPServerConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigsByIDs(ctx, ids)
m.queryLatencies.WithLabelValues("GetMCPServerConfigsByIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigsByIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserToken(ctx, arg)
m.queryLatencies.WithLabelValues("GetMCPServerUserToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserToken").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserTokensByUserID(ctx, userID)
m.queryLatencies.WithLabelValues("GetMCPServerUserTokensByUserID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserTokensByUserID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
start := time.Now()
r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg)
@@ -3184,6 +3272,14 @@ func (m queryMetricsStore) InsertLicense(ctx context.Context, arg database.Inser
return r0, r1
}
func (m queryMetricsStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.InsertMCPServerConfig(ctx, arg)
m.queryLatencies.WithLabelValues("InsertMCPServerConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertMCPServerConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
start := time.Now()
r0, r1 := m.s.InsertMemoryResourceMonitor(ctx, arg)
@@ -3856,6 +3952,14 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatMCPServerIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMCPServerIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatMessageByID(ctx, arg)
@@ -3960,6 +4064,14 @@ func (m queryMetricsStore) UpdateInboxNotificationReadStatus(ctx context.Context
return r0
}
func (m queryMetricsStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.UpdateMCPServerConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateMCPServerConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateMCPServerConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
start := time.Now()
r0, r1 := m.s.UpdateMemberRoles(ctx, arg)
@@ -4696,6 +4808,14 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro
return r0
}
func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertMCPServerUserToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserToken").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
start := time.Now()
r0 := m.s.UpsertNotificationReportGeneratorLog(ctx, arg)
+222
View File
@@ -334,6 +334,20 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx)
}
// CleanupDeletedMCPServerIDsFromChats mocks base method.
func (m *MockStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupDeletedMCPServerIDsFromChats", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupDeletedMCPServerIDsFromChats indicates an expected call of CleanupDeletedMCPServerIDsFromChats.
func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
}
// CountAIBridgeInterceptions mocks base method.
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -769,6 +783,34 @@ func (mr *MockStoreMockRecorder) DeleteLicense(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLicense", reflect.TypeOf((*MockStore)(nil).DeleteLicense), ctx, id)
}
// DeleteMCPServerConfigByID mocks base method.
func (m *MockStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteMCPServerConfigByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteMCPServerConfigByID indicates an expected call of DeleteMCPServerConfigByID.
func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id)
}
// DeleteMCPServerUserToken mocks base method.
func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteMCPServerUserToken", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteMCPServerUserToken indicates an expected call of DeleteMCPServerUserToken.
func (mr *MockStoreMockRecorder) DeleteMCPServerUserToken(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserToken), ctx, arg)
}
// DeleteOAuth2ProviderAppByClientID mocks base method.
func (m *MockStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -2437,6 +2479,21 @@ func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
}
// GetEnabledMCPServerConfigs mocks base method.
func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEnabledMCPServerConfigs", ctx)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEnabledMCPServerConfigs indicates an expected call of GetEnabledMCPServerConfigs.
func (mr *MockStoreMockRecorder) GetEnabledMCPServerConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledMCPServerConfigs), ctx)
}
// GetExternalAuthLink mocks base method.
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
m.ctrl.T.Helper()
@@ -2542,6 +2599,21 @@ func (mr *MockStoreMockRecorder) GetFilteredInboxNotificationsByUserID(ctx, arg
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).GetFilteredInboxNotificationsByUserID), ctx, arg)
}
// GetForcedMCPServerConfigs mocks base method.
func (m *MockStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetForcedMCPServerConfigs", ctx)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetForcedMCPServerConfigs indicates an expected call of GetForcedMCPServerConfigs.
func (mr *MockStoreMockRecorder) GetForcedMCPServerConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForcedMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetForcedMCPServerConfigs), ctx)
}
// GetGitSSHKey mocks base method.
func (m *MockStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
m.ctrl.T.Helper()
@@ -2842,6 +2914,96 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx)
}
// GetMCPServerConfigByID mocks base method.
func (m *MockStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigByID", ctx, id)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigByID indicates an expected call of GetMCPServerConfigByID.
func (mr *MockStoreMockRecorder) GetMCPServerConfigByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigByID), ctx, id)
}
// GetMCPServerConfigBySlug mocks base method.
func (m *MockStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigBySlug", ctx, slug)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigBySlug indicates an expected call of GetMCPServerConfigBySlug.
func (mr *MockStoreMockRecorder) GetMCPServerConfigBySlug(ctx, slug any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigBySlug", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigBySlug), ctx, slug)
}
// GetMCPServerConfigs mocks base method.
func (m *MockStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigs", ctx)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigs indicates an expected call of GetMCPServerConfigs.
func (mr *MockStoreMockRecorder) GetMCPServerConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigs), ctx)
}
// GetMCPServerConfigsByIDs mocks base method.
func (m *MockStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigsByIDs", ctx, ids)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigsByIDs indicates an expected call of GetMCPServerConfigsByIDs.
func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids)
}
// GetMCPServerUserToken mocks base method.
func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerUserToken", ctx, arg)
ret0, _ := ret[0].(database.MCPServerUserToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerUserToken indicates an expected call of GetMCPServerUserToken.
func (mr *MockStoreMockRecorder) GetMCPServerUserToken(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserToken), ctx, arg)
}
// GetMCPServerUserTokensByUserID mocks base method.
func (m *MockStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerUserTokensByUserID", ctx, userID)
ret0, _ := ret[0].([]database.MCPServerUserToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerUserTokensByUserID indicates an expected call of GetMCPServerUserTokensByUserID.
func (mr *MockStoreMockRecorder) GetMCPServerUserTokensByUserID(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserTokensByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserTokensByUserID), ctx, userID)
}
// GetNotificationMessagesByStatus mocks base method.
func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
m.ctrl.T.Helper()
@@ -5957,6 +6119,21 @@ func (mr *MockStoreMockRecorder) InsertLicense(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertLicense", reflect.TypeOf((*MockStore)(nil).InsertLicense), ctx, arg)
}
// InsertMCPServerConfig mocks base method.
func (m *MockStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertMCPServerConfig", ctx, arg)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertMCPServerConfig indicates an expected call of InsertMCPServerConfig.
func (mr *MockStoreMockRecorder) InsertMCPServerConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMCPServerConfig", reflect.TypeOf((*MockStore)(nil).InsertMCPServerConfig), ctx, arg)
}
// InsertMemoryResourceMonitor mocks base method.
func (m *MockStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
m.ctrl.T.Helper()
@@ -7256,6 +7433,21 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
}
// UpdateChatMCPServerIDs mocks base method.
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatMCPServerIDs", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatMCPServerIDs indicates an expected call of UpdateChatMCPServerIDs.
func (mr *MockStoreMockRecorder) UpdateChatMCPServerIDs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMCPServerIDs", reflect.TypeOf((*MockStore)(nil).UpdateChatMCPServerIDs), ctx, arg)
}
// UpdateChatMessageByID mocks base method.
func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -7449,6 +7641,21 @@ func (mr *MockStoreMockRecorder) UpdateInboxNotificationReadStatus(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInboxNotificationReadStatus", reflect.TypeOf((*MockStore)(nil).UpdateInboxNotificationReadStatus), ctx, arg)
}
// UpdateMCPServerConfig mocks base method.
func (m *MockStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateMCPServerConfig", ctx, arg)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateMCPServerConfig indicates an expected call of UpdateMCPServerConfig.
func (mr *MockStoreMockRecorder) UpdateMCPServerConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMCPServerConfig", reflect.TypeOf((*MockStore)(nil).UpdateMCPServerConfig), ctx, arg)
}
// UpdateMemberRoles mocks base method.
func (m *MockStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
m.ctrl.T.Helper()
@@ -8773,6 +8980,21 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value)
}
// UpsertMCPServerUserToken mocks base method.
func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertMCPServerUserToken", ctx, arg)
ret0, _ := ret[0].(database.MCPServerUserToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertMCPServerUserToken indicates an expected call of UpsertMCPServerUserToken.
func (mr *MockStoreMockRecorder) UpsertMCPServerUserToken(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserToken), ctx, arg)
}
// UpsertNotificationReportGeneratorLog mocks base method.
func (m *MockStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
m.ctrl.T.Helper()
+94 -1
View File
@@ -1393,7 +1393,8 @@ CREATE TABLE chats (
last_model_config_id uuid NOT NULL,
archived boolean DEFAULT false NOT NULL,
last_error text,
mode chat_mode
mode chat_mode,
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL
);
CREATE TABLE connection_logs (
@@ -1670,6 +1671,53 @@ CREATE SEQUENCE licenses_id_seq
ALTER SEQUENCE licenses_id_seq OWNED BY licenses.id;
CREATE TABLE mcp_server_configs (
id uuid DEFAULT gen_random_uuid() NOT NULL,
display_name text NOT NULL,
slug text NOT NULL,
description text DEFAULT ''::text NOT NULL,
icon_url text DEFAULT ''::text NOT NULL,
transport text DEFAULT 'streamable_http'::text NOT NULL,
url text NOT NULL,
auth_type text DEFAULT 'none'::text NOT NULL,
oauth2_client_id text DEFAULT ''::text NOT NULL,
oauth2_client_secret text DEFAULT ''::text NOT NULL,
oauth2_client_secret_key_id text,
oauth2_auth_url text DEFAULT ''::text NOT NULL,
oauth2_token_url text DEFAULT ''::text NOT NULL,
oauth2_scopes text DEFAULT ''::text NOT NULL,
api_key_header text DEFAULT 'Authorization'::text NOT NULL,
api_key_value text DEFAULT ''::text NOT NULL,
api_key_value_key_id text,
custom_headers text DEFAULT '{}'::text NOT NULL,
custom_headers_key_id text,
tool_allow_list text[] DEFAULT '{}'::text[] NOT NULL,
tool_deny_list text[] DEFAULT '{}'::text[] NOT NULL,
availability text DEFAULT 'default_off'::text NOT NULL,
enabled boolean DEFAULT false NOT NULL,
created_by uuid,
updated_by uuid,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))),
CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))),
CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text])))
);
CREATE TABLE mcp_server_user_tokens (
id uuid DEFAULT gen_random_uuid() NOT NULL,
mcp_server_config_id uuid NOT NULL,
user_id uuid NOT NULL,
access_token text NOT NULL,
access_token_key_id text,
refresh_token text DEFAULT ''::text NOT NULL,
refresh_token_key_id text,
token_type text DEFAULT 'Bearer'::text NOT NULL,
expiry timestamp with time zone,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE TABLE notification_messages (
id uuid NOT NULL,
notification_template_id uuid NOT NULL,
@@ -3343,6 +3391,18 @@ ALTER TABLE ONLY licenses
ALTER TABLE ONLY licenses
ADD CONSTRAINT licenses_pkey PRIMARY KEY (id);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id);
ALTER TABLE ONLY notification_messages
ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id);
@@ -3691,6 +3751,12 @@ CREATE INDEX idx_inbox_notifications_user_id_read_at ON inbox_notifications USIN
CREATE INDEX idx_inbox_notifications_user_id_template_id_targets ON inbox_notifications USING btree (user_id, template_id, targets);
CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs USING btree (enabled) WHERE (enabled = true);
CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs USING btree (enabled, availability) WHERE ((enabled = true) AND (availability = 'force_on'::text));
CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens USING btree (user_id);
CREATE INDEX idx_notification_messages_status ON notification_messages USING btree (status);
CREATE INDEX idx_organization_member_organization_id_uuid ON organization_members USING btree (organization_id);
@@ -4015,6 +4081,33 @@ ALTER TABLE ONLY jfrog_xray_scans
ALTER TABLE ONLY jfrog_xray_scans
ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY notification_messages
ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
@@ -40,6 +40,15 @@ const (
ForeignKeyInboxNotificationsUserID ForeignKeyConstraint = "inbox_notifications_user_id_fkey" // ALTER TABLE ONLY inbox_notifications ADD CONSTRAINT inbox_notifications_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyJfrogXrayScansAgentID ForeignKeyConstraint = "jfrog_xray_scans_agent_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
ForeignKeyJfrogXrayScansWorkspaceID ForeignKeyConstraint = "jfrog_xray_scans_workspace_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyMcpServerConfigsAPIKeyValueKeyID ForeignKeyConstraint = "mcp_server_configs_api_key_value_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerConfigsCreatedBy ForeignKeyConstraint = "mcp_server_configs_created_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
ForeignKeyMcpServerConfigsCustomHeadersKeyID ForeignKeyConstraint = "mcp_server_configs_custom_headers_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerConfigsOauth2ClientSecretKeyID ForeignKeyConstraint = "mcp_server_configs_oauth2_client_secret_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerConfigsUpdatedBy ForeignKeyConstraint = "mcp_server_configs_updated_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
ForeignKeyMcpServerUserTokensAccessTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_access_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerUserTokensMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_tokens_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
ForeignKeyMcpServerUserTokensRefreshTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_refresh_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerUserTokensUserID ForeignKeyConstraint = "mcp_server_user_tokens_user_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyNotificationMessagesNotificationTemplateID ForeignKeyConstraint = "notification_messages_notification_template_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
ForeignKeyNotificationMessagesUserID ForeignKeyConstraint = "notification_messages_user_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyNotificationPreferencesNotificationTemplateID ForeignKeyConstraint = "notification_preferences_notification_template_id_fkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
@@ -0,0 +1,6 @@
ALTER TABLE chats DROP COLUMN IF EXISTS mcp_server_ids;
DROP INDEX IF EXISTS idx_mcp_server_configs_enabled;
DROP INDEX IF EXISTS idx_mcp_server_configs_forced;
DROP INDEX IF EXISTS idx_mcp_server_user_tokens_user_id;
DROP TABLE IF EXISTS mcp_server_user_tokens;
DROP TABLE IF EXISTS mcp_server_configs;
@@ -0,0 +1,75 @@
CREATE TABLE mcp_server_configs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
-- Display
display_name TEXT NOT NULL,
slug TEXT NOT NULL UNIQUE,
description TEXT NOT NULL DEFAULT '',
icon_url TEXT NOT NULL DEFAULT '',
-- Connection
transport TEXT NOT NULL DEFAULT 'streamable_http'
CHECK (transport IN ('streamable_http', 'sse')),
url TEXT NOT NULL,
-- Authentication
auth_type TEXT NOT NULL DEFAULT 'none'
CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers')),
-- OAuth2 config (when auth_type = 'oauth2')
oauth2_client_id TEXT NOT NULL DEFAULT '',
oauth2_client_secret TEXT NOT NULL DEFAULT '',
oauth2_client_secret_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
oauth2_auth_url TEXT NOT NULL DEFAULT '',
oauth2_token_url TEXT NOT NULL DEFAULT '',
oauth2_scopes TEXT NOT NULL DEFAULT '',
-- API key config (when auth_type = 'api_key')
api_key_header TEXT NOT NULL DEFAULT 'Authorization',
api_key_value TEXT NOT NULL DEFAULT '',
api_key_value_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
-- Custom headers (when auth_type = 'custom_headers')
custom_headers TEXT NOT NULL DEFAULT '{}',
custom_headers_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
-- Tool governance
tool_allow_list TEXT[] NOT NULL DEFAULT '{}',
tool_deny_list TEXT[] NOT NULL DEFAULT '{}',
-- Availability policy
availability TEXT NOT NULL DEFAULT 'default_off'
CHECK (availability IN ('force_on', 'default_on', 'default_off')),
-- Lifecycle
enabled BOOLEAN NOT NULL DEFAULT false,
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
updated_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE mcp_server_user_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
mcp_server_config_id UUID NOT NULL REFERENCES mcp_server_configs(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
access_token TEXT NOT NULL,
access_token_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
refresh_token TEXT NOT NULL DEFAULT '',
refresh_token_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
token_type TEXT NOT NULL DEFAULT 'Bearer',
expiry TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
UNIQUE (mcp_server_config_id, user_id)
);
-- Add MCP server selection to chats (per-chat, like model_config_id)
ALTER TABLE chats ADD COLUMN mcp_server_ids UUID[] NOT NULL DEFAULT '{}';
CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs(enabled) WHERE enabled = TRUE;
CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs(enabled, availability) WHERE enabled = TRUE AND availability = 'force_on';
CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens(user_id);
@@ -0,0 +1,48 @@
INSERT INTO mcp_server_configs (
id,
display_name,
slug,
url,
transport,
auth_type,
availability,
enabled,
created_by,
updated_by,
created_at,
updated_at
) VALUES (
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
'Fixture MCP Server',
'fixture-mcp-server',
'https://mcp.example.com/sse',
'sse',
'none',
'default_on',
TRUE,
'30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com
'30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00'
);
INSERT INTO mcp_server_user_tokens (
id,
mcp_server_config_id,
user_id,
access_token,
token_type,
created_at,
updated_at
)
SELECT
'b2c3d4e5-f6a7-8901-bcde-f12345678901',
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
id,
'fixture-access-token',
'Bearer',
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:00+00'
FROM users
ORDER BY created_at, id
LIMIT 1;
+1
View File
@@ -787,6 +787,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
); err != nil {
return nil, err
}
+45
View File
@@ -4167,6 +4167,7 @@ type Chat struct {
Archived bool `db:"archived" json:"archived"`
LastError sql.NullString `db:"last_error" json:"last_error"`
Mode NullChatMode `db:"mode" json:"mode"`
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
}
type ChatDiffStatus struct {
@@ -4452,6 +4453,50 @@ type License struct {
UUID uuid.UUID `db:"uuid" json:"uuid"`
}
type MCPServerConfig struct {
ID uuid.UUID `db:"id" json:"id"`
DisplayName string `db:"display_name" json:"display_name"`
Slug string `db:"slug" json:"slug"`
Description string `db:"description" json:"description"`
IconURL string `db:"icon_url" json:"icon_url"`
Transport string `db:"transport" json:"transport"`
Url string `db:"url" json:"url"`
AuthType string `db:"auth_type" json:"auth_type"`
OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"`
OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"`
OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"`
OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"`
OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"`
OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"`
APIKeyHeader string `db:"api_key_header" json:"api_key_header"`
APIKeyValue string `db:"api_key_value" json:"api_key_value"`
APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"`
CustomHeaders string `db:"custom_headers" json:"custom_headers"`
CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"`
ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"`
ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"`
Availability string `db:"availability" json:"availability"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type MCPServerUserToken struct {
ID uuid.UUID `db:"id" json:"id"`
MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
AccessToken string `db:"access_token" json:"access_token"`
AccessTokenKeyID sql.NullString `db:"access_token_key_id" json:"access_token_key_id"`
RefreshToken string `db:"refresh_token" json:"refresh_token"`
RefreshTokenKeyID sql.NullString `db:"refresh_token_key_id" json:"refresh_token_key_id"`
TokenType string `db:"token_type" json:"token_type"`
Expiry sql.NullTime `db:"expiry" json:"expiry"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type NotificationMessage struct {
ID uuid.UUID `db:"id" json:"id"`
NotificationTemplateID uuid.UUID `db:"notification_template_id" json:"notification_template_id"`
+15
View File
@@ -74,6 +74,7 @@ type sqlcQuerier interface {
CleanTailnetCoordinators(ctx context.Context) error
CleanTailnetLostPeers(ctx context.Context) error
CleanTailnetTunnels(ctx context.Context) error
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
@@ -110,6 +111,8 @@ type sqlcQuerier interface {
DeleteGroupByID(ctx context.Context, id uuid.UUID) error
DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error
DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error
DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error
DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error
DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error
@@ -269,6 +272,7 @@ type sqlcQuerier interface {
GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.Context, provisionerJobIds []uuid.UUID) ([]GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error)
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg GetFailedWorkspaceBuildsByTemplateIDParams) ([]GetFailedWorkspaceBuildsByTemplateIDRow, error)
@@ -284,6 +288,7 @@ type sqlcQuerier interface {
// param created_at_opt: The created_at timestamp to filter by. This parameter is usd for pagination - it fetches notifications created before the specified timestamp if it is not the zero value
// param limit_opt: The limit of notifications to fetch. If the limit is not specified, it defaults to 25
GetFilteredInboxNotificationsByUserID(ctx context.Context, arg GetFilteredInboxNotificationsByUserIDParams) ([]InboxNotification, error)
GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
@@ -312,6 +317,12 @@ type sqlcQuerier interface {
GetLicenseByID(ctx context.Context, id int32) (License, error)
GetLicenses(ctx context.Context) ([]License, error)
GetLogoURL(ctx context.Context) (string, error)
GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error)
GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error)
GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error)
GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error)
GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error)
GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error)
// Fetch the notification report generator log indicating recent activity.
GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error)
@@ -675,6 +686,7 @@ type sqlcQuerier interface {
InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error
InsertInboxNotification(ctx context.Context, arg InsertInboxNotificationParams) (InboxNotification, error)
InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error)
InsertMCPServerConfig(ctx context.Context, arg InsertMCPServerConfigParams) (MCPServerConfig, error)
InsertMemoryResourceMonitor(ctx context.Context, arg InsertMemoryResourceMonitorParams) (WorkspaceAgentMemoryResourceMonitor, error)
// Inserts any group by name that does not exist. All new groups are given
// a random uuid, are inserted into the same organization. They have the default
@@ -796,6 +808,7 @@ type sqlcQuerier interface {
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error)
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
@@ -813,6 +826,7 @@ type sqlcQuerier interface {
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error)
UpdateInboxNotificationReadStatus(ctx context.Context, arg UpdateInboxNotificationReadStatusParams) error
UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPServerConfigParams) (MCPServerConfig, error)
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)
UpdateMemoryResourceMonitor(ctx context.Context, arg UpdateMemoryResourceMonitorParams) error
UpdateNotificationTemplateMethodByID(ctx context.Context, arg UpdateNotificationTemplateMethodByIDParams) (NotificationTemplate, error)
@@ -917,6 +931,7 @@ type sqlcQuerier interface {
UpsertHealthSettings(ctx context.Context, value string) error
UpsertLastUpdateCheck(ctx context.Context, value string) error
UpsertLogoURL(ctx context.Context, value string) error
UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error)
// Insert or update notification report generator logs with recent activity.
UpsertNotificationReportGeneratorLog(ctx context.Context, arg UpsertNotificationReportGeneratorLogParams) error
UpsertNotificationsSettings(ctx context.Context, value string) error
+5 -5
View File
@@ -1653,12 +1653,12 @@ func TestDefaultProxy(t *testing.T) {
require.NoError(t, err, "get def proxy")
require.Equal(t, defProxy.DisplayName, "Default")
require.Equal(t, defProxy.IconUrl, "/emojis/1f3e1.png")
require.Equal(t, defProxy.IconURL, "/emojis/1f3e1.png")
// Set the proxy values
args := database.UpsertDefaultProxyParams{
DisplayName: "displayname",
IconUrl: "/icon.png",
IconURL: "/icon.png",
}
err = db.UpsertDefaultProxy(ctx, args)
require.NoError(t, err, "insert def proxy")
@@ -1666,12 +1666,12 @@ func TestDefaultProxy(t *testing.T) {
defProxy, err = db.GetDefaultProxyConfig(ctx)
require.NoError(t, err, "get def proxy")
require.Equal(t, defProxy.DisplayName, args.DisplayName)
require.Equal(t, defProxy.IconUrl, args.IconUrl)
require.Equal(t, defProxy.IconURL, args.IconURL)
// Upsert values
args = database.UpsertDefaultProxyParams{
DisplayName: "newdisplayname",
IconUrl: "/newicon.png",
IconURL: "/newicon.png",
}
err = db.UpsertDefaultProxy(ctx, args)
require.NoError(t, err, "upsert def proxy")
@@ -1679,7 +1679,7 @@ func TestDefaultProxy(t *testing.T) {
defProxy, err = db.GetDefaultProxyConfig(ctx)
require.NoError(t, err, "get def proxy")
require.Equal(t, defProxy.DisplayName, args.DisplayName)
require.Equal(t, defProxy.IconUrl, args.IconUrl)
require.Equal(t, defProxy.IconURL, args.IconURL)
// Ensure other site configs are the same
found, err := db.GetDeploymentID(ctx)
File diff suppressed because it is too large Load Diff
+15 -2
View File
@@ -180,7 +180,8 @@ INSERT INTO chats (
root_chat_id,
last_model_config_id,
title,
mode
mode,
mcp_server_ids
) VALUES (
@owner_id::uuid,
sqlc.narg('workspace_id')::uuid,
@@ -188,7 +189,8 @@ INSERT INTO chats (
sqlc.narg('root_chat_id')::uuid,
@last_model_config_id::uuid,
@title::text,
sqlc.narg('mode')::chat_mode
sqlc.narg('mode')::chat_mode,
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[])
)
RETURNING
*;
@@ -295,6 +297,17 @@ WHERE
RETURNING
*;
-- name: UpdateChatMCPServerIDs :one
UPDATE
chats
SET
mcp_server_ids = @mcp_server_ids::uuid[],
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: AcquireChats :many
-- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED
-- to prevent multiple replicas from acquiring the same chat.
@@ -0,0 +1,213 @@
-- name: GetMCPServerConfigByID :one
SELECT
*
FROM
mcp_server_configs
WHERE
id = @id::uuid;
-- name: GetMCPServerConfigBySlug :one
SELECT
*
FROM
mcp_server_configs
WHERE
slug = @slug::text;
-- name: GetMCPServerConfigs :many
SELECT
*
FROM
mcp_server_configs
ORDER BY
display_name ASC;
-- name: GetEnabledMCPServerConfigs :many
SELECT
*
FROM
mcp_server_configs
WHERE
enabled = TRUE
ORDER BY
display_name ASC;
-- name: GetMCPServerConfigsByIDs :many
SELECT
*
FROM
mcp_server_configs
WHERE
id = ANY(@ids::uuid[])
ORDER BY
display_name ASC;
-- name: GetForcedMCPServerConfigs :many
SELECT
*
FROM
mcp_server_configs
WHERE
enabled = TRUE
AND availability = 'force_on'
ORDER BY
display_name ASC;
-- name: InsertMCPServerConfig :one
INSERT INTO mcp_server_configs (
display_name,
slug,
description,
icon_url,
transport,
url,
auth_type,
oauth2_client_id,
oauth2_client_secret,
oauth2_client_secret_key_id,
oauth2_auth_url,
oauth2_token_url,
oauth2_scopes,
api_key_header,
api_key_value,
api_key_value_key_id,
custom_headers,
custom_headers_key_id,
tool_allow_list,
tool_deny_list,
availability,
enabled,
created_by,
updated_by
) VALUES (
@display_name::text,
@slug::text,
@description::text,
@icon_url::text,
@transport::text,
@url::text,
@auth_type::text,
@oauth2_client_id::text,
@oauth2_client_secret::text,
sqlc.narg('oauth2_client_secret_key_id')::text,
@oauth2_auth_url::text,
@oauth2_token_url::text,
@oauth2_scopes::text,
@api_key_header::text,
@api_key_value::text,
sqlc.narg('api_key_value_key_id')::text,
@custom_headers::text,
sqlc.narg('custom_headers_key_id')::text,
@tool_allow_list::text[],
@tool_deny_list::text[],
@availability::text,
@enabled::boolean,
@created_by::uuid,
@updated_by::uuid
)
RETURNING
*;
-- name: UpdateMCPServerConfig :one
UPDATE
mcp_server_configs
SET
display_name = @display_name::text,
slug = @slug::text,
description = @description::text,
icon_url = @icon_url::text,
transport = @transport::text,
url = @url::text,
auth_type = @auth_type::text,
oauth2_client_id = @oauth2_client_id::text,
oauth2_client_secret = @oauth2_client_secret::text,
oauth2_client_secret_key_id = sqlc.narg('oauth2_client_secret_key_id')::text,
oauth2_auth_url = @oauth2_auth_url::text,
oauth2_token_url = @oauth2_token_url::text,
oauth2_scopes = @oauth2_scopes::text,
api_key_header = @api_key_header::text,
api_key_value = @api_key_value::text,
api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text,
custom_headers = @custom_headers::text,
custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text,
tool_allow_list = @tool_allow_list::text[],
tool_deny_list = @tool_deny_list::text[],
availability = @availability::text,
enabled = @enabled::boolean,
updated_by = @updated_by::uuid,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: DeleteMCPServerConfigByID :exec
DELETE FROM
mcp_server_configs
WHERE
id = @id::uuid;
-- name: GetMCPServerUserToken :one
SELECT
*
FROM
mcp_server_user_tokens
WHERE
mcp_server_config_id = @mcp_server_config_id::uuid
AND user_id = @user_id::uuid;
-- name: GetMCPServerUserTokensByUserID :many
SELECT
*
FROM
mcp_server_user_tokens
WHERE
user_id = @user_id::uuid;
-- name: UpsertMCPServerUserToken :one
INSERT INTO mcp_server_user_tokens (
mcp_server_config_id,
user_id,
access_token,
access_token_key_id,
refresh_token,
refresh_token_key_id,
token_type,
expiry
) VALUES (
@mcp_server_config_id::uuid,
@user_id::uuid,
@access_token::text,
sqlc.narg('access_token_key_id')::text,
@refresh_token::text,
sqlc.narg('refresh_token_key_id')::text,
@token_type::text,
sqlc.narg('expiry')::timestamptz
)
ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET
access_token = @access_token::text,
access_token_key_id = sqlc.narg('access_token_key_id')::text,
refresh_token = @refresh_token::text,
refresh_token_key_id = sqlc.narg('refresh_token_key_id')::text,
token_type = @token_type::text,
expiry = sqlc.narg('expiry')::timestamptz,
updated_at = NOW()
RETURNING
*;
-- name: DeleteMCPServerUserToken :exec
DELETE FROM
mcp_server_user_tokens
WHERE
mcp_server_config_id = @mcp_server_config_id::uuid
AND user_id = @user_id::uuid;
-- name: CleanupDeletedMCPServerIDsFromChats :exec
UPDATE chats
SET mcp_server_ids = (
SELECT COALESCE(array_agg(sid), '{}')
FROM unnest(chats.mcp_server_ids) AS sid
WHERE sid IN (SELECT id FROM mcp_server_configs)
)
WHERE mcp_server_ids != '{}'
AND NOT (mcp_server_ids <@ COALESCE((SELECT array_agg(id) FROM mcp_server_configs), '{}'));
+22
View File
@@ -236,6 +236,28 @@ sql:
aibridge_token_usage: AIBridgeTokenUsage
aibridge_user_prompt: AIBridgeUserPrompt
aibridge_model_thought: AIBridgeModelThought
mcp_server_config: MCPServerConfig
mcp_server_configs: MCPServerConfigs
mcp_server_user_token: MCPServerUserToken
mcp_server_user_tokens: MCPServerUserTokens
mcp_server_tool_snapshot: MCPServerToolSnapshot
mcp_server_tool_snapshots: MCPServerToolSnapshots
mcp_server_config_id: MCPServerConfigID
mcp_server_ids: MCPServerIDs
icon_url: IconURL
oauth2_client_id: OAuth2ClientID
oauth2_client_secret: OAuth2ClientSecret
oauth2_client_secret_key_id: OAuth2ClientSecretKeyID
oauth2_auth_url: OAuth2AuthURL
oauth2_token_url: OAuth2TokenURL
oauth2_scopes: OAuth2Scopes
api_key_header: APIKeyHeader
api_key_value: APIKeyValue
api_key_value_key_id: APIKeyValueKeyID
custom_headers_key_id: CustomHeadersKeyID
tools_json: ToolsJSON
access_token_key_id: AccessTokenKeyID
refresh_token_key_id: RefreshTokenKeyID
rules:
- name: do-not-use-public-schema-in-queries
message: "do not use public schema in queries"
+4
View File
@@ -42,6 +42,10 @@ const (
UniqueJfrogXrayScansPkey UniqueConstraint = "jfrog_xray_scans_pkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id);
UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt);
UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id);
UniqueMcpServerConfigsPkey UniqueConstraint = "mcp_server_configs_pkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id);
UniqueMcpServerConfigsSlugKey UniqueConstraint = "mcp_server_configs_slug_key" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug);
UniqueMcpServerUserTokensMcpServerConfigIDUserIDKey UniqueConstraint = "mcp_server_user_tokens_mcp_server_config_id_user_id_key" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id);
UniqueMcpServerUserTokensPkey UniqueConstraint = "mcp_server_user_tokens_pkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id);
UniqueNotificationMessagesPkey UniqueConstraint = "notification_messages_pkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id);
UniqueNotificationPreferencesPkey UniqueConstraint = "notification_preferences_pkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_pkey PRIMARY KEY (user_id, notification_template_id);
UniqueNotificationReportGeneratorLogsPkey UniqueConstraint = "notification_report_generator_logs_pkey" // ALTER TABLE ONLY notification_report_generator_logs ADD CONSTRAINT notification_report_generator_logs_pkey PRIMARY KEY (notification_template_id);
+921
View File
@@ -0,0 +1,921 @@
package coderd
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/codersdk"
)
// @Summary List MCP server configs
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) listMCPServerConfigs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
// Admin users can see all MCP server configs (including disabled
// ones) for management purposes. Non-admin users see only enabled
// configs, which is sufficient for using the chat feature.
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
var configs []database.MCPServerConfig
var err error
if isAdmin {
configs, err = api.Database.GetMCPServerConfigs(ctx)
} else {
//nolint:gocritic // All authenticated users need to read enabled MCP server configs to use the chat feature.
configs, err = api.Database.GetEnabledMCPServerConfigs(dbauthz.AsSystemRestricted(ctx))
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list MCP server configs.",
Detail: err.Error(),
})
return
}
// Look up the calling user's OAuth2 tokens so we can populate
// auth_connected per server.
//nolint:gocritic // Need to check user tokens across all servers.
userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get user tokens.",
Detail: err.Error(),
})
return
}
tokenMap := make(map[uuid.UUID]bool, len(userTokens))
for _, t := range userTokens {
tokenMap[t.MCPServerConfigID] = true
}
resp := make([]codersdk.MCPServerConfig, 0, len(configs))
for _, config := range configs {
var sdkConfig codersdk.MCPServerConfig
if isAdmin {
sdkConfig = convertMCPServerConfig(config)
} else {
sdkConfig = convertMCPServerConfigRedacted(config)
}
if config.AuthType == "oauth2" {
sdkConfig.AuthConnected = tokenMap[config.ID]
} else {
sdkConfig.AuthConnected = true
}
resp = append(resp, sdkConfig)
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
// @Summary Create MCP server config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
var req codersdk.CreateMCPServerConfigRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
// Validate auth-type-dependent fields.
switch req.AuthType {
case "oauth2":
if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "OAuth2 auth type requires oauth2_client_id, oauth2_auth_url, and oauth2_token_url.",
})
return
}
case "api_key":
if req.APIKeyHeader == "" || req.APIKeyValue == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "API key auth type requires api_key_header and api_key_value.",
})
return
}
case "custom_headers":
if len(req.CustomHeaders) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Custom headers auth type requires at least one custom header.",
})
return
}
}
customHeadersJSON, err := marshalCustomHeaders(req.CustomHeaders)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid custom headers.",
Detail: err.Error(),
})
return
}
inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
DisplayName: strings.TrimSpace(req.DisplayName),
Slug: strings.TrimSpace(req.Slug),
Description: strings.TrimSpace(req.Description),
IconURL: strings.TrimSpace(req.IconURL),
Transport: strings.TrimSpace(req.Transport),
Url: strings.TrimSpace(req.URL),
AuthType: strings.TrimSpace(req.AuthType),
OAuth2ClientID: strings.TrimSpace(req.OAuth2ClientID),
OAuth2ClientSecret: strings.TrimSpace(req.OAuth2ClientSecret),
OAuth2ClientSecretKeyID: sql.NullString{},
OAuth2AuthURL: strings.TrimSpace(req.OAuth2AuthURL),
OAuth2TokenURL: strings.TrimSpace(req.OAuth2TokenURL),
OAuth2Scopes: strings.TrimSpace(req.OAuth2Scopes),
APIKeyHeader: strings.TrimSpace(req.APIKeyHeader),
APIKeyValue: strings.TrimSpace(req.APIKeyValue),
APIKeyValueKeyID: sql.NullString{},
CustomHeaders: customHeadersJSON,
CustomHeadersKeyID: sql.NullString{},
ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)),
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
Availability: strings.TrimSpace(req.Availability),
Enabled: req.Enabled,
CreatedBy: apiKey.UserID,
UpdatedBy: apiKey.UserID,
})
if err != nil {
switch {
case database.IsUniqueViolation(err):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "MCP server config already exists.",
Detail: err.Error(),
})
return
case database.IsCheckViolation(err):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid MCP server config.",
Detail: err.Error(),
})
return
default:
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to create MCP server config.",
Detail: err.Error(),
})
return
}
}
httpapi.Write(ctx, rw, http.StatusCreated, convertMCPServerConfig(inserted))
}
// @Summary Get MCP server config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) getMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
mcpServerID, ok := parseMCPServerConfigID(rw, r)
if !ok {
return
}
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
var config database.MCPServerConfig
var err error
if isAdmin {
config, err = api.Database.GetMCPServerConfigByID(ctx, mcpServerID)
} else {
//nolint:gocritic // All authenticated users can view enabled MCP server configs.
config, err = api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID)
if err == nil && !config.Enabled {
httpapi.ResourceNotFound(rw)
return
}
}
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get MCP server config.",
Detail: err.Error(),
})
return
}
var sdkConfig codersdk.MCPServerConfig
if isAdmin {
sdkConfig = convertMCPServerConfig(config)
} else {
sdkConfig = convertMCPServerConfigRedacted(config)
}
// Populate AuthConnected for the calling user.
if config.AuthType == "oauth2" {
//nolint:gocritic // Need to check user token for this server.
userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get user tokens.",
Detail: err.Error(),
})
return
}
for _, t := range userTokens {
if t.MCPServerConfigID == config.ID {
sdkConfig.AuthConnected = true
break
}
}
} else {
sdkConfig.AuthConnected = true
}
httpapi.Write(ctx, rw, http.StatusOK, sdkConfig)
}
// @Summary Update MCP server config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
mcpServerID, ok := parseMCPServerConfigID(rw, r)
if !ok {
return
}
var req codersdk.UpdateMCPServerConfigRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
// Pre-validate custom headers before entering the transaction.
var customHeadersJSON string
if req.CustomHeaders != nil {
var chErr error
customHeadersJSON, chErr = marshalCustomHeaders(*req.CustomHeaders)
if chErr != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid custom headers.",
Detail: chErr.Error(),
})
return
}
}
var updated database.MCPServerConfig
err := api.Database.InTx(func(tx database.Store) error {
existing, err := tx.GetMCPServerConfigByID(ctx, mcpServerID)
if err != nil {
return err
}
displayName := existing.DisplayName
if req.DisplayName != nil {
displayName = strings.TrimSpace(*req.DisplayName)
}
slug := existing.Slug
if req.Slug != nil {
slug = strings.TrimSpace(*req.Slug)
}
description := existing.Description
if req.Description != nil {
description = strings.TrimSpace(*req.Description)
}
iconURL := existing.IconURL
if req.IconURL != nil {
iconURL = strings.TrimSpace(*req.IconURL)
}
transport := existing.Transport
if req.Transport != nil {
transport = strings.TrimSpace(*req.Transport)
}
serverURL := existing.Url
if req.URL != nil {
serverURL = strings.TrimSpace(*req.URL)
}
authType := existing.AuthType
if req.AuthType != nil {
authType = strings.TrimSpace(*req.AuthType)
}
oauth2ClientID := existing.OAuth2ClientID
if req.OAuth2ClientID != nil {
oauth2ClientID = strings.TrimSpace(*req.OAuth2ClientID)
}
oauth2ClientSecret := existing.OAuth2ClientSecret
oauth2ClientSecretKeyID := existing.OAuth2ClientSecretKeyID
if req.OAuth2ClientSecret != nil {
oauth2ClientSecret = strings.TrimSpace(*req.OAuth2ClientSecret)
// Clear the key ID when the secret is explicitly updated.
oauth2ClientSecretKeyID = sql.NullString{}
}
oauth2AuthURL := existing.OAuth2AuthURL
if req.OAuth2AuthURL != nil {
oauth2AuthURL = strings.TrimSpace(*req.OAuth2AuthURL)
}
oauth2TokenURL := existing.OAuth2TokenURL
if req.OAuth2TokenURL != nil {
oauth2TokenURL = strings.TrimSpace(*req.OAuth2TokenURL)
}
oauth2Scopes := existing.OAuth2Scopes
if req.OAuth2Scopes != nil {
oauth2Scopes = strings.TrimSpace(*req.OAuth2Scopes)
}
apiKeyHeader := existing.APIKeyHeader
if req.APIKeyHeader != nil {
apiKeyHeader = strings.TrimSpace(*req.APIKeyHeader)
}
apiKeyValue := existing.APIKeyValue
apiKeyValueKeyID := existing.APIKeyValueKeyID
if req.APIKeyValue != nil {
apiKeyValue = strings.TrimSpace(*req.APIKeyValue)
// Clear the key ID when the value is explicitly updated.
apiKeyValueKeyID = sql.NullString{}
}
customHeaders := existing.CustomHeaders
customHeadersKeyID := existing.CustomHeadersKeyID
if req.CustomHeaders != nil {
customHeaders = customHeadersJSON
// Clear the key ID when headers are explicitly updated.
customHeadersKeyID = sql.NullString{}
}
toolAllowList := existing.ToolAllowList
if req.ToolAllowList != nil {
toolAllowList = coalesceStringSlice(trimStringSlice(*req.ToolAllowList))
}
toolDenyList := existing.ToolDenyList
if req.ToolDenyList != nil {
toolDenyList = coalesceStringSlice(trimStringSlice(*req.ToolDenyList))
}
availability := existing.Availability
if req.Availability != nil {
availability = strings.TrimSpace(*req.Availability)
}
enabled := existing.Enabled
if req.Enabled != nil {
enabled = *req.Enabled
}
// When auth_type changes, clear fields belonging to the
// previous auth type so stale secrets don't persist.
if authType != existing.AuthType {
switch authType {
case "none":
oauth2ClientID = ""
oauth2ClientSecret = ""
oauth2ClientSecretKeyID = sql.NullString{}
oauth2AuthURL = ""
oauth2TokenURL = ""
oauth2Scopes = ""
apiKeyHeader = ""
apiKeyValue = ""
apiKeyValueKeyID = sql.NullString{}
customHeaders = "{}"
customHeadersKeyID = sql.NullString{}
case "oauth2":
apiKeyHeader = ""
apiKeyValue = ""
apiKeyValueKeyID = sql.NullString{}
customHeaders = "{}"
customHeadersKeyID = sql.NullString{}
case "api_key":
oauth2ClientID = ""
oauth2ClientSecret = ""
oauth2ClientSecretKeyID = sql.NullString{}
oauth2AuthURL = ""
oauth2TokenURL = ""
oauth2Scopes = ""
customHeaders = "{}"
customHeadersKeyID = sql.NullString{}
case "custom_headers":
oauth2ClientID = ""
oauth2ClientSecret = ""
oauth2ClientSecretKeyID = sql.NullString{}
oauth2AuthURL = ""
oauth2TokenURL = ""
oauth2Scopes = ""
apiKeyHeader = ""
apiKeyValue = ""
apiKeyValueKeyID = sql.NullString{}
}
}
updated, err = tx.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{
DisplayName: displayName,
Slug: slug,
Description: description,
IconURL: iconURL,
Transport: transport,
Url: serverURL,
AuthType: authType,
OAuth2ClientID: oauth2ClientID,
OAuth2ClientSecret: oauth2ClientSecret,
OAuth2ClientSecretKeyID: oauth2ClientSecretKeyID,
OAuth2AuthURL: oauth2AuthURL,
OAuth2TokenURL: oauth2TokenURL,
OAuth2Scopes: oauth2Scopes,
APIKeyHeader: apiKeyHeader,
APIKeyValue: apiKeyValue,
APIKeyValueKeyID: apiKeyValueKeyID,
CustomHeaders: customHeaders,
CustomHeadersKeyID: customHeadersKeyID,
ToolAllowList: toolAllowList,
ToolDenyList: toolDenyList,
Availability: availability,
Enabled: enabled,
UpdatedBy: apiKey.UserID,
ID: existing.ID,
})
return err
}, nil)
if err != nil {
switch {
case httpapi.Is404Error(err):
httpapi.ResourceNotFound(rw)
return
case database.IsUniqueViolation(err):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "MCP server config slug already exists.",
Detail: err.Error(),
})
return
case database.IsCheckViolation(err):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid MCP server config.",
Detail: err.Error(),
})
return
default:
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update MCP server config.",
Detail: err.Error(),
})
return
}
}
httpapi.Write(ctx, rw, http.StatusOK, convertMCPServerConfig(updated))
}
// @Summary Delete MCP server config
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) deleteMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
mcpServerID, ok := parseMCPServerConfigID(rw, r)
if !ok {
return
}
if _, err := api.Database.GetMCPServerConfigByID(ctx, mcpServerID); err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get MCP server config.",
Detail: err.Error(),
})
return
}
if err := api.Database.DeleteMCPServerConfigByID(ctx, mcpServerID); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to delete MCP server config.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// @Summary Initiate MCP server OAuth2 connect
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
// Redirects the user to the MCP server's OAuth2 authorization URL.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) mcpServerOAuth2Connect(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
mcpServerID, ok := parseMCPServerConfigID(rw, r)
if !ok {
return
}
//nolint:gocritic // Any authenticated user can initiate OAuth2 for an enabled MCP server.
config, err := api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get MCP server config.",
Detail: err.Error(),
})
return
}
if !config.Enabled {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "MCP server is not enabled.",
})
return
}
if config.AuthType != "oauth2" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "MCP server does not use OAuth2 authentication.",
})
return
}
if config.OAuth2AuthURL == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "MCP server OAuth2 authorization URL is not configured.",
})
return
}
// Build the authorization URL. The frontend opens this in a popup.
// The callback URL is on our server; after the exchange we store
// the token and close the popup.
state := uuid.New().String()
http.SetCookie(rw, api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
Name: "mcp_oauth2_state_" + config.ID.String(),
Value: state,
Path: fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/callback", config.ID),
MaxAge: 600, // 10 minutes
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}))
oauth2Config := &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.OAuth2AuthURL,
TokenURL: config.OAuth2TokenURL,
},
RedirectURL: fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), config.ID),
}
var scopes []string
if config.OAuth2Scopes != "" {
scopes = strings.Split(config.OAuth2Scopes, " ")
}
oauth2Config.Scopes = scopes
authURL := oauth2Config.AuthCodeURL(state)
http.Redirect(rw, r, authURL, http.StatusTemporaryRedirect)
}
// @Summary Handle MCP server OAuth2 callback
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
// Exchanges the authorization code for tokens and stores them.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) mcpServerOAuth2Callback(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
mcpServerID, ok := parseMCPServerConfigID(rw, r)
if !ok {
return
}
//nolint:gocritic // Any authenticated user can complete OAuth2 for an enabled MCP server.
config, err := api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get MCP server config.",
Detail: err.Error(),
})
return
}
if !config.Enabled {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "MCP server is not enabled.",
})
return
}
if config.AuthType != "oauth2" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "MCP server does not use OAuth2 authentication.",
})
return
}
// Check if the OAuth2 provider returned an error (e.g., user
// denied consent).
if oauthError := r.URL.Query().Get("error"); oauthError != "" {
desc := r.URL.Query().Get("error_description")
if desc == "" {
desc = oauthError
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "OAuth2 provider returned an error.",
Detail: desc,
})
return
}
code := r.URL.Query().Get("code")
if code == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing authorization code.",
})
return
}
// Validate the state parameter for CSRF protection.
expectedState := ""
if cookie, err := r.Cookie("mcp_oauth2_state_" + config.ID.String()); err == nil {
expectedState = cookie.Value
}
actualState := r.URL.Query().Get("state")
if expectedState == "" || actualState != expectedState {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid or missing OAuth2 state parameter.",
})
return
}
// Clear the state cookie.
http.SetCookie(rw, api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
Name: "mcp_oauth2_state_" + config.ID.String(),
Value: "",
Path: fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/callback", config.ID),
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}))
// Exchange the authorization code for tokens.
oauth2Config := &oauth2.Config{
ClientID: config.OAuth2ClientID,
ClientSecret: config.OAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.OAuth2AuthURL,
TokenURL: config.OAuth2TokenURL,
},
RedirectURL: fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), config.ID),
}
var scopes []string
if config.OAuth2Scopes != "" {
scopes = strings.Split(config.OAuth2Scopes, " ")
}
oauth2Config.Scopes = scopes
// Use the deployment's HTTP client for the token exchange to
// respect proxy settings and avoid using http.DefaultClient.
exchangeCtx := context.WithValue(ctx, oauth2.HTTPClient, api.HTTPClient)
token, err := oauth2Config.Exchange(exchangeCtx, code)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadGateway, codersdk.Response{
Message: "Failed to exchange authorization code for token.",
Detail: "The OAuth2 token exchange with the upstream provider failed.",
})
return
}
// Store the token for the user.
refreshToken := ""
if token.RefreshToken != "" {
refreshToken = token.RefreshToken
}
var expiry sql.NullTime
if !token.Expiry.IsZero() {
expiry = sql.NullTime{Time: token.Expiry, Valid: true}
}
//nolint:gocritic // Users store their own tokens.
_, err = api.Database.UpsertMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), database.UpsertMCPServerUserTokenParams{
MCPServerConfigID: mcpServerID,
UserID: apiKey.UserID,
AccessToken: token.AccessToken,
AccessTokenKeyID: sql.NullString{},
RefreshToken: refreshToken,
RefreshTokenKeyID: sql.NullString{},
TokenType: token.TokenType,
Expiry: expiry,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to store OAuth2 token.",
Detail: err.Error(),
})
return
}
// Respond with a simple HTML page that closes the popup window.
rw.Header().Set("Content-Security-Policy", "default-src 'none'; script-src 'unsafe-inline'")
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte(`<!DOCTYPE html><html><body><script>
if (window.opener) {
window.opener.postMessage({type: "mcp-oauth2-complete", serverID: "` + config.ID.String() + `"}, "` + api.AccessURL.String() + `");
window.close();
} else {
document.body.innerText = "Authentication successful. You may close this window.";
}
</script></body></html>`))
}
// @Summary Disconnect MCP server OAuth2 token
// @x-apidocgen {"skip": true}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
// Removes the user's stored OAuth2 token for an MCP server.
func (api *API) mcpServerOAuth2Disconnect(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
mcpServerID, ok := parseMCPServerConfigID(rw, r)
if !ok {
return
}
//nolint:gocritic // Users manage their own tokens.
err := api.Database.DeleteMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), database.DeleteMCPServerUserTokenParams{
MCPServerConfigID: mcpServerID,
UserID: apiKey.UserID,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to disconnect OAuth2 token.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// parseMCPServerConfigID extracts the MCP server config UUID from the
// "mcpServer" path parameter.
func parseMCPServerConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
mcpServerID, err := uuid.Parse(chi.URLParam(r, "mcpServer"))
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid MCP server config ID.",
Detail: err.Error(),
})
return uuid.Nil, false
}
return mcpServerID, true
}
// convertMCPServerConfig converts a database MCP server config to the
// SDK type. Secrets are never returned; only has_* booleans are set.
// Admin-only fields (OAuth2 client ID, auth URLs, etc.) are included.
func convertMCPServerConfig(config database.MCPServerConfig) codersdk.MCPServerConfig {
return codersdk.MCPServerConfig{
ID: config.ID,
DisplayName: config.DisplayName,
Slug: config.Slug,
Description: config.Description,
IconURL: config.IconURL,
Transport: config.Transport,
URL: config.Url,
AuthType: config.AuthType,
OAuth2ClientID: config.OAuth2ClientID,
HasOAuth2Secret: config.OAuth2ClientSecret != "",
OAuth2AuthURL: config.OAuth2AuthURL,
OAuth2TokenURL: config.OAuth2TokenURL,
OAuth2Scopes: config.OAuth2Scopes,
APIKeyHeader: config.APIKeyHeader,
HasAPIKey: config.APIKeyValue != "",
HasCustomHeaders: len(config.CustomHeaders) > 0 && config.CustomHeaders != "{}",
ToolAllowList: coalesceStringSlice(config.ToolAllowList),
ToolDenyList: coalesceStringSlice(config.ToolDenyList),
Availability: config.Availability,
Enabled: config.Enabled,
CreatedAt: config.CreatedAt,
UpdatedAt: config.UpdatedAt,
}
}
// convertMCPServerConfigRedacted is the same as convertMCPServerConfig
// but strips admin-only fields (OAuth2 details, API key header) for
// non-admin callers.
func convertMCPServerConfigRedacted(config database.MCPServerConfig) codersdk.MCPServerConfig {
c := convertMCPServerConfig(config)
c.URL = ""
c.Transport = ""
c.OAuth2ClientID = ""
c.OAuth2AuthURL = ""
c.OAuth2TokenURL = ""
c.OAuth2Scopes = ""
c.APIKeyHeader = ""
return c
}
// marshalCustomHeaders encodes a map of custom headers to JSON for
// database storage. A nil map produces an empty JSON object.
func marshalCustomHeaders(headers map[string]string) (string, error) {
if headers == nil {
return "{}", nil
}
encoded, err := json.Marshal(headers)
if err != nil {
return "", err
}
return string(encoded), nil
}
// trimStringSlice trims whitespace from each element and drops empty
// strings.
func trimStringSlice(ss []string) []string {
if ss == nil {
return nil
}
out := make([]string, 0, len(ss))
for _, s := range ss {
if trimmed := strings.TrimSpace(s); trimmed != "" {
out = append(out, trimmed)
}
}
return out
}
// coalesceStringSlice returns ss if non-nil, otherwise an empty
// non-nil slice. This prevents pq.Array from sending NULL for
// NOT NULL text[] columns.
func coalesceStringSlice(ss []string) []string {
if ss == nil {
return []string{}
}
return ss
}
+489
View File
@@ -0,0 +1,489 @@
package coderd_test
import (
"encoding/json"
"net/http"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// mcpDeploymentValues returns deployment values with the agents
// experiment enabled, which is required by the MCP server config
// endpoints.
func mcpDeploymentValues(t testing.TB) *codersdk.DeploymentValues {
t.Helper()
values := coderdtest.DeploymentValues(t)
values.Experiments = []string{string(codersdk.ExperimentAgents)}
return values
}
// newMCPClient creates a test server with the agents experiment
// enabled and returns the admin client.
func newMCPClient(t testing.TB) *codersdk.Client {
t.Helper()
return coderdtest.New(t, &coderdtest.Options{
DeploymentValues: mcpDeploymentValues(t),
})
}
// createMCPServerConfig is a helper that creates a minimal enabled
// MCP server config with auth_type=none.
func createMCPServerConfig(t testing.TB, client *codersdk.Client, slug string, enabled bool) codersdk.MCPServerConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
config, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Test Server " + slug,
Slug: slug,
Description: "A test MCP server.",
IconURL: "https://example.com/icon.png",
Transport: "streamable_http",
URL: "https://mcp.example.com/" + slug,
AuthType: "none",
Availability: "default_on",
Enabled: enabled,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
return config
}
func TestMCPServerConfigsCRUD(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
// Create a config with all fields populated including OAuth2
// secrets so we can verify they are not leaked.
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "My MCP Server",
Slug: "my-mcp-server",
Description: "Integration test server.",
IconURL: "https://example.com/icon.png",
Transport: "streamable_http",
URL: "https://mcp.example.com/v1",
AuthType: "oauth2",
OAuth2ClientID: "client-id-123",
OAuth2ClientSecret: "super-secret-value",
OAuth2AuthURL: "https://auth.example.com/authorize",
OAuth2TokenURL: "https://auth.example.com/token",
OAuth2Scopes: "read write",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
require.NotEqual(t, uuid.Nil, created.ID)
require.Equal(t, "My MCP Server", created.DisplayName)
require.Equal(t, "my-mcp-server", created.Slug)
require.Equal(t, "Integration test server.", created.Description)
require.Equal(t, "streamable_http", created.Transport)
require.Equal(t, "https://mcp.example.com/v1", created.URL)
require.Equal(t, "oauth2", created.AuthType)
require.Equal(t, "client-id-123", created.OAuth2ClientID)
require.Equal(t, "default_on", created.Availability)
require.True(t, created.Enabled)
// Verify the secret is indicated but never returned.
require.True(t, created.HasOAuth2Secret)
// Verify the config appears in the list.
configs, err := client.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, configs, 1)
require.Equal(t, created.ID, configs[0].ID)
require.True(t, configs[0].HasOAuth2Secret)
// Update display name and availability.
newName := "Renamed Server"
newAvail := "force_on"
updated, err := client.UpdateMCPServerConfig(ctx, created.ID, codersdk.UpdateMCPServerConfigRequest{
DisplayName: &newName,
Availability: &newAvail,
})
require.NoError(t, err)
require.Equal(t, "Renamed Server", updated.DisplayName)
require.Equal(t, "force_on", updated.Availability)
// Unchanged fields should remain the same.
require.Equal(t, "my-mcp-server", updated.Slug)
require.Equal(t, "oauth2", updated.AuthType)
// Verify the update took effect through the list.
configs, err = client.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, configs, 1)
require.Equal(t, "Renamed Server", configs[0].DisplayName)
require.Equal(t, "force_on", configs[0].Availability)
// Delete it.
err = client.DeleteMCPServerConfig(ctx, created.ID)
require.NoError(t, err)
// Verify it's gone.
configs, err = client.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Empty(t, configs)
}
func TestMCPServerConfigsNonAdmin(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newMCPClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
// Admin creates two configs: one enabled, one disabled.
_ = createMCPServerConfig(t, adminClient, "enabled-server", true)
_ = createMCPServerConfig(t, adminClient, "disabled-server", false)
// Admin sees both.
adminConfigs, err := adminClient.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, adminConfigs, 2)
// Regular user sees only the enabled one.
memberConfigs, err := memberClient.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, memberConfigs, 1)
require.Equal(t, "enabled-server", memberConfigs[0].Slug)
}
// TestMCPServerConfigsSecretsNeverLeaked is a load-bearing test that
// ensures secret fields (OAuth2 client secret, API key value, custom
// headers) are never present in API responses for any caller. If this
// test fails, it means a code change accidentally started exposing
// secrets. See: https://github.com/coder/coder/pull/23227#discussion_r2959461109
func TestMCPServerConfigsSecretsNeverLeaked(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newMCPClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
// Create a config with ALL secret fields populated.
created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Secrets Test",
Slug: "secrets-test",
Transport: "streamable_http",
URL: "https://mcp.example.com/secrets",
AuthType: "oauth2",
OAuth2ClientID: "client-id-secret-test",
OAuth2ClientSecret: "THIS-IS-A-SECRET-VALUE",
OAuth2AuthURL: "https://auth.example.com/authorize",
OAuth2TokenURL: "https://auth.example.com/token",
OAuth2Scopes: "read write",
APIKeyHeader: "X-Api-Key",
APIKeyValue: "THIS-IS-A-SECRET-API-KEY",
CustomHeaders: map[string]string{"X-Custom": "THIS-IS-A-SECRET-HEADER"},
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
// The sentinel values we must never see in any JSON response.
secrets := []string{
"THIS-IS-A-SECRET-VALUE",
"THIS-IS-A-SECRET-API-KEY",
"THIS-IS-A-SECRET-HEADER",
}
assertNoSecrets := func(t *testing.T, label string, v interface{}) {
t.Helper()
data, err := json.Marshal(v)
require.NoError(t, err)
jsonStr := string(data)
for _, secret := range secrets {
assert.False(t, strings.Contains(jsonStr, secret),
"%s: JSON response contains secret %q", label, secret)
}
}
// Verify the create response doesn't leak secrets.
assertNoSecrets(t, "admin create response", created)
// Verify boolean indicators are set correctly.
require.True(t, created.HasOAuth2Secret, "HasOAuth2Secret should be true")
require.True(t, created.HasAPIKey, "HasAPIKey should be true")
require.True(t, created.HasCustomHeaders, "HasCustomHeaders should be true")
// Admin list endpoint.
adminConfigs, err := adminClient.MCPServerConfigs(ctx)
require.NoError(t, err)
require.NotEmpty(t, adminConfigs)
for _, cfg := range adminConfigs {
assertNoSecrets(t, "admin list", cfg)
}
// Admin get-by-ID endpoint.
adminSingle, err := adminClient.MCPServerConfigByID(ctx, created.ID)
require.NoError(t, err)
assertNoSecrets(t, "admin get-by-id", adminSingle)
// Non-admin list endpoint.
memberConfigs, err := memberClient.MCPServerConfigs(ctx)
require.NoError(t, err)
require.NotEmpty(t, memberConfigs)
for _, cfg := range memberConfigs {
assertNoSecrets(t, "member list", cfg)
// Non-admin should also not see admin-only fields.
assert.Empty(t, cfg.OAuth2ClientID, "member should not see OAuth2ClientID")
assert.Empty(t, cfg.OAuth2AuthURL, "member should not see OAuth2AuthURL")
assert.Empty(t, cfg.OAuth2TokenURL, "member should not see OAuth2TokenURL")
assert.Empty(t, cfg.APIKeyHeader, "member should not see APIKeyHeader")
assert.Empty(t, cfg.OAuth2Scopes, "member should not see OAuth2Scopes")
assert.Empty(t, cfg.URL, "member should not see URL")
assert.Empty(t, cfg.Transport, "member should not see Transport")
}
// Non-admin get-by-ID endpoint.
memberSingle, err := memberClient.MCPServerConfigByID(ctx, created.ID)
require.NoError(t, err)
assertNoSecrets(t, "member get-by-id", memberSingle)
assert.Empty(t, memberSingle.OAuth2ClientID, "member should not see OAuth2ClientID")
assert.Empty(t, memberSingle.OAuth2AuthURL, "member should not see OAuth2AuthURL")
assert.Empty(t, memberSingle.OAuth2TokenURL, "member should not see OAuth2TokenURL")
assert.Empty(t, memberSingle.OAuth2Scopes, "member should not see OAuth2Scopes")
assert.Empty(t, memberSingle.APIKeyHeader, "member should not see APIKeyHeader")
assert.Empty(t, memberSingle.URL, "member should not see URL")
assert.Empty(t, memberSingle.Transport, "member should not see Transport")
}
func TestMCPServerConfigsAuthConnected(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newMCPClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
// Create an oauth2 server config (enabled).
created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "OAuth Server",
Slug: "oauth-server",
Transport: "streamable_http",
URL: "https://mcp.example.com/oauth",
AuthType: "oauth2",
OAuth2ClientID: "cid",
OAuth2AuthURL: "https://auth.example.com/authorize",
OAuth2TokenURL: "https://auth.example.com/token",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
// Regular user lists configs — auth_connected should be false
// because no token has been stored.
memberConfigs, err := memberClient.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, memberConfigs, 1)
require.Equal(t, created.ID, memberConfigs[0].ID)
require.False(t, memberConfigs[0].AuthConnected)
// Also create a non-oauth server. It should report
// auth_connected=true because no auth is needed.
_ = createMCPServerConfig(t, adminClient, "no-auth-server", true)
memberConfigs, err = memberClient.MCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, memberConfigs, 2)
for _, cfg := range memberConfigs {
if cfg.AuthType == "none" {
require.True(t, cfg.AuthConnected)
} else {
require.False(t, cfg.AuthConnected)
}
}
}
func TestMCPServerConfigsAvailability(t *testing.T) {
t.Parallel()
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
validValues := []string{"force_on", "default_on", "default_off"}
for _, av := range validValues {
av := av
t.Run(av, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Server " + av,
Slug: "server-" + av,
Transport: "streamable_http",
URL: "https://mcp.example.com/" + av,
AuthType: "none",
Availability: av,
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
require.Equal(t, av, created.Availability)
})
}
t.Run("InvalidAvailability", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Bad Availability",
Slug: "bad-avail",
Transport: "streamable_http",
URL: "https://mcp.example.com/bad",
AuthType: "none",
Availability: "always_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}
func TestMCPServerConfigsUniqueSlug(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "First",
Slug: "test-server",
Transport: "streamable_http",
URL: "https://mcp.example.com/first",
AuthType: "none",
Availability: "default_off",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
// Attempt to create another config with the same slug.
_, err = client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "Second",
Slug: "test-server",
Transport: "streamable_http",
URL: "https://mcp.example.com/second",
AuthType: "none",
Availability: "default_off",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
}
func TestMCPServerConfigsOAuth2Disconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newMCPClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
DisplayName: "OAuth Disconnect Test",
Slug: "oauth-disconnect",
Transport: "streamable_http",
URL: "https://mcp.example.com/oauth-disc",
AuthType: "oauth2",
OAuth2ClientID: "cid",
OAuth2AuthURL: "https://auth.example.com/authorize",
OAuth2TokenURL: "https://auth.example.com/token",
Availability: "default_on",
Enabled: true,
ToolAllowList: []string{},
ToolDenyList: []string{},
})
require.NoError(t, err)
// Disconnect should succeed even when no token exists (idempotent).
err = memberClient.MCPServerOAuth2Disconnect(ctx, created.ID)
require.NoError(t, err)
}
func TestChatWithMCPServerIDs(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newMCPClient(t)
_ = coderdtest.CreateFirstUser(t, client)
// Create the chat model config required for creating a chat.
_ = createChatModelConfigForMCP(t, client)
// Create an enabled MCP server config.
mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true)
// Create a chat referencing the MCP server.
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello with mcp server",
},
},
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
})
require.NoError(t, err)
require.NotEqual(t, uuid.Nil, chat.ID)
require.Contains(t, chat.MCPServerIDs, mcpConfig.ID)
// Fetch the chat and verify the MCP server IDs persist.
fetched, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID)
}
// createChatModelConfigForMCP sets up a chat provider and model
// config so that CreateChat succeeds. This mirrors the helper in
// chats_test.go but is defined here to avoid coupling.
func createChatModelConfigForMCP(t testing.TB, client *codersdk.Client) codersdk.ChatModelConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
return modelConfig
}
+1 -1
View File
@@ -41,7 +41,7 @@ func (api *API) PrimaryRegion(ctx context.Context) (codersdk.Region, error) {
ID: deploymentID,
Name: "primary",
DisplayName: proxy.DisplayName,
IconURL: proxy.IconUrl,
IconURL: proxy.IconURL,
Healthy: true,
PathAppURL: api.AccessURL.String(),
WildcardHostname: appurl.SubdomainAppHost(api.AppHostname, api.AccessURL),
+3
View File
@@ -49,6 +49,7 @@ type Chat struct {
CreatedAt time.Time `json:"created_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
Archived bool `json:"archived"`
MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"`
}
// ChatMessage represents a single message in a chat.
@@ -267,6 +268,7 @@ type CreateChatRequest struct {
Content []ChatInputPart `json:"content"`
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
}
// UpdateChatRequest is the request to update a chat.
@@ -279,6 +281,7 @@ type UpdateChatRequest struct {
type CreateChatMessageRequest struct {
Content []ChatInputPart `json:"content"`
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
MCPServerIDs *[]uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
}
// EditChatMessageRequest is the request to edit a user message in a chat.
+191
View File
@@ -0,0 +1,191 @@
package codersdk
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
// MCPServerOAuth2ConnectURL returns the URL the user should visit to
// start the OAuth2 flow for an MCP server. The frontend opens this
// in a new window/popup.
func (c *Client) MCPServerOAuth2ConnectURL(id uuid.UUID) string {
return fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/connect", c.URL.String(), id)
}
// MCPServerOAuth2Disconnect removes the user's OAuth2 token for an
// MCP server.
func (c *Client) MCPServerOAuth2Disconnect(ctx context.Context, id uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/disconnect", id), nil)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// MCPServerConfig represents an admin-configured MCP server.
type MCPServerConfig struct {
ID uuid.UUID `json:"id" format:"uuid"`
DisplayName string `json:"display_name"`
Slug string `json:"slug"`
Description string `json:"description"`
IconURL string `json:"icon_url"`
Transport string `json:"transport"` // "streamable_http" or "sse"
URL string `json:"url"`
AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers"
// OAuth2 fields (only populated for admins).
OAuth2ClientID string `json:"oauth2_client_id,omitempty"`
HasOAuth2Secret bool `json:"has_oauth2_secret"`
OAuth2AuthURL string `json:"oauth2_auth_url,omitempty"`
OAuth2TokenURL string `json:"oauth2_token_url,omitempty"`
OAuth2Scopes string `json:"oauth2_scopes,omitempty"`
// API key fields (only populated for admins).
APIKeyHeader string `json:"api_key_header,omitempty"`
HasAPIKey bool `json:"has_api_key"`
HasCustomHeaders bool `json:"has_custom_headers"`
// Tool governance.
ToolAllowList []string `json:"tool_allow_list"`
ToolDenyList []string `json:"tool_deny_list"`
// Availability policy set by admin.
Availability string `json:"availability"` // "force_on", "default_on", "default_off"
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
// Per-user state (populated for non-admin requests).
AuthConnected bool `json:"auth_connected"`
}
// CreateMCPServerConfigRequest is the request to create a new MCP server config.
type CreateMCPServerConfigRequest struct {
DisplayName string `json:"display_name" validate:"required"`
Slug string `json:"slug" validate:"required"`
Description string `json:"description"`
IconURL string `json:"icon_url"`
Transport string `json:"transport" validate:"required,oneof=streamable_http sse"`
URL string `json:"url" validate:"required,url"`
AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers"`
OAuth2ClientID string `json:"oauth2_client_id,omitempty"`
OAuth2ClientSecret string `json:"oauth2_client_secret,omitempty"`
OAuth2AuthURL string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"`
OAuth2TokenURL string `json:"oauth2_token_url,omitempty" validate:"omitempty,url"`
OAuth2Scopes string `json:"oauth2_scopes,omitempty"`
APIKeyHeader string `json:"api_key_header,omitempty"`
APIKeyValue string `json:"api_key_value,omitempty"`
CustomHeaders map[string]string `json:"custom_headers,omitempty"`
ToolAllowList []string `json:"tool_allow_list,omitempty"`
ToolDenyList []string `json:"tool_deny_list,omitempty"`
Availability string `json:"availability" validate:"required,oneof=force_on default_on default_off"`
Enabled bool `json:"enabled"`
}
// UpdateMCPServerConfigRequest is the request to update an MCP server config.
type UpdateMCPServerConfigRequest struct {
DisplayName *string `json:"display_name,omitempty"`
Slug *string `json:"slug,omitempty"`
Description *string `json:"description,omitempty"`
IconURL *string `json:"icon_url,omitempty"`
Transport *string `json:"transport,omitempty" validate:"omitempty,oneof=streamable_http sse"`
URL *string `json:"url,omitempty" validate:"omitempty,url"`
AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers"`
OAuth2ClientID *string `json:"oauth2_client_id,omitempty"`
OAuth2ClientSecret *string `json:"oauth2_client_secret,omitempty"`
OAuth2AuthURL *string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"`
OAuth2TokenURL *string `json:"oauth2_token_url,omitempty" validate:"omitempty,url"`
OAuth2Scopes *string `json:"oauth2_scopes,omitempty"`
APIKeyHeader *string `json:"api_key_header,omitempty"`
APIKeyValue *string `json:"api_key_value,omitempty"`
CustomHeaders *map[string]string `json:"custom_headers,omitempty"`
ToolAllowList *[]string `json:"tool_allow_list,omitempty"`
ToolDenyList *[]string `json:"tool_deny_list,omitempty"`
Availability *string `json:"availability,omitempty" validate:"omitempty,oneof=force_on default_on default_off"`
Enabled *bool `json:"enabled,omitempty"`
}
func (c *Client) MCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/mcp/servers", nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var configs []MCPServerConfig
return configs, json.NewDecoder(res.Body).Decode(&configs)
}
func (c *Client) MCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), nil)
if err != nil {
return MCPServerConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return MCPServerConfig{}, ReadBodyAsError(res)
}
var config MCPServerConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
func (c *Client) CreateMCPServerConfig(ctx context.Context, req CreateMCPServerConfigRequest) (MCPServerConfig, error) {
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/mcp/servers", req)
if err != nil {
return MCPServerConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return MCPServerConfig{}, ReadBodyAsError(res)
}
var config MCPServerConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
func (c *Client) UpdateMCPServerConfig(ctx context.Context, id uuid.UUID, req UpdateMCPServerConfigRequest) (MCPServerConfig, error) {
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), req)
if err != nil {
return MCPServerConfig{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return MCPServerConfig{}, ReadBodyAsError(res)
}
var config MCPServerConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
func (c *Client) DeleteMCPServerConfig(ctx context.Context, id uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), nil)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
+2 -2
View File
@@ -204,7 +204,7 @@ func (api *API) patchPrimaryWorkspaceProxy(req codersdk.PatchWorkspaceProxy, rw
args := database.UpsertDefaultProxyParams{
DisplayName: req.DisplayName,
IconUrl: req.Icon,
IconURL: req.Icon,
}
if req.DisplayName == "" || req.Icon == "" {
// If the user has not specified an update value, use the existing value.
@@ -217,7 +217,7 @@ func (api *API) patchPrimaryWorkspaceProxy(req codersdk.PatchWorkspaceProxy, rw
args.DisplayName = existing.DisplayName
}
if req.Icon == "" {
args.IconUrl = existing.IconUrl
args.IconURL = existing.IconURL
}
}
+195
View File
@@ -471,6 +471,201 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat
return provider, nil
}
// decryptMCPServerConfig decrypts all encrypted fields on a
// single MCPServerConfig in place.
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
if err := db.decryptField(&cfg.OAuth2ClientSecret, cfg.OAuth2ClientSecretKeyID); err != nil {
return err
}
if err := db.decryptField(&cfg.APIKeyValue, cfg.APIKeyValueKeyID); err != nil {
return err
}
return db.decryptField(&cfg.CustomHeaders, cfg.CustomHeadersKeyID)
}
// decryptMCPServerUserToken decrypts all encrypted fields on a
// single MCPServerUserToken in place.
func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) error {
if err := db.decryptField(&tok.AccessToken, tok.AccessTokenKeyID); err != nil {
return err
}
return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID)
}
func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
cfg, err := db.Store.GetMCPServerConfigByID(ctx, id)
if err != nil {
return database.MCPServerConfig{}, err
}
if err := db.decryptMCPServerConfig(&cfg); err != nil {
return database.MCPServerConfig{}, err
}
return cfg, nil
}
func (db *dbCrypt) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
cfg, err := db.Store.GetMCPServerConfigBySlug(ctx, slug)
if err != nil {
return database.MCPServerConfig{}, err
}
if err := db.decryptMCPServerConfig(&cfg); err != nil {
return database.MCPServerConfig{}, err
}
return cfg, nil
}
func (db *dbCrypt) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
cfgs, err := db.Store.GetMCPServerConfigs(ctx)
if err != nil {
return nil, err
}
for i := range cfgs {
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
return nil, err
}
}
return cfgs, nil
}
func (db *dbCrypt) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
cfgs, err := db.Store.GetMCPServerConfigsByIDs(ctx, ids)
if err != nil {
return nil, err
}
for i := range cfgs {
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
return nil, err
}
}
return cfgs, nil
}
func (db *dbCrypt) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
cfgs, err := db.Store.GetEnabledMCPServerConfigs(ctx)
if err != nil {
return nil, err
}
for i := range cfgs {
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
return nil, err
}
}
return cfgs, nil
}
func (db *dbCrypt) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
cfgs, err := db.Store.GetForcedMCPServerConfigs(ctx)
if err != nil {
return nil, err
}
for i := range cfgs {
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
return nil, err
}
}
return cfgs, nil
}
func (db *dbCrypt) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
tok, err := db.Store.GetMCPServerUserToken(ctx, arg)
if err != nil {
return database.MCPServerUserToken{}, err
}
if err := db.decryptMCPServerUserToken(&tok); err != nil {
return database.MCPServerUserToken{}, err
}
return tok, nil
}
func (db *dbCrypt) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
toks, err := db.Store.GetMCPServerUserTokensByUserID(ctx, userID)
if err != nil {
return nil, err
}
for i := range toks {
if err := db.decryptMCPServerUserToken(&toks[i]); err != nil {
return nil, err
}
}
return toks, nil
}
func (db *dbCrypt) InsertMCPServerConfig(ctx context.Context, params database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
params.OAuth2ClientSecretKeyID = sql.NullString{}
} else if err := db.encryptField(&params.OAuth2ClientSecret, &params.OAuth2ClientSecretKeyID); err != nil {
return database.MCPServerConfig{}, err
}
if strings.TrimSpace(params.APIKeyValue) == "" {
params.APIKeyValueKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKeyValue, &params.APIKeyValueKeyID); err != nil {
return database.MCPServerConfig{}, err
}
if strings.TrimSpace(params.CustomHeaders) == "" {
params.CustomHeadersKeyID = sql.NullString{}
} else if err := db.encryptField(&params.CustomHeaders, &params.CustomHeadersKeyID); err != nil {
return database.MCPServerConfig{}, err
}
cfg, err := db.Store.InsertMCPServerConfig(ctx, params)
if err != nil {
return database.MCPServerConfig{}, err
}
if err := db.decryptMCPServerConfig(&cfg); err != nil {
return database.MCPServerConfig{}, err
}
return cfg, nil
}
func (db *dbCrypt) UpdateMCPServerConfig(ctx context.Context, params database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
params.OAuth2ClientSecretKeyID = sql.NullString{}
} else if err := db.encryptField(&params.OAuth2ClientSecret, &params.OAuth2ClientSecretKeyID); err != nil {
return database.MCPServerConfig{}, err
}
if strings.TrimSpace(params.APIKeyValue) == "" {
params.APIKeyValueKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKeyValue, &params.APIKeyValueKeyID); err != nil {
return database.MCPServerConfig{}, err
}
if strings.TrimSpace(params.CustomHeaders) == "" {
params.CustomHeadersKeyID = sql.NullString{}
} else if err := db.encryptField(&params.CustomHeaders, &params.CustomHeadersKeyID); err != nil {
return database.MCPServerConfig{}, err
}
cfg, err := db.Store.UpdateMCPServerConfig(ctx, params)
if err != nil {
return database.MCPServerConfig{}, err
}
if err := db.decryptMCPServerConfig(&cfg); err != nil {
return database.MCPServerConfig{}, err
}
return cfg, nil
}
func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if strings.TrimSpace(params.AccessToken) == "" {
params.AccessTokenKeyID = sql.NullString{}
} else if err := db.encryptField(&params.AccessToken, &params.AccessTokenKeyID); err != nil {
return database.MCPServerUserToken{}, err
}
if strings.TrimSpace(params.RefreshToken) == "" {
params.RefreshTokenKeyID = sql.NullString{}
} else if err := db.encryptField(&params.RefreshToken, &params.RefreshTokenKeyID); err != nil {
return database.MCPServerUserToken{}, err
}
tok, err := db.Store.UpsertMCPServerUserToken(ctx, params)
if err != nil {
return database.MCPServerUserToken{}, err
}
if err := db.decryptMCPServerUserToken(&tok); err != nil {
return database.MCPServerUserToken{}, err
}
return tok, nil
}
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
// If no cipher is loaded, then we can't encrypt anything!
if db.ciphers == nil || db.primaryCipherDigest == "" {
+299
View File
@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -878,3 +879,301 @@ func fakeBase64RandomData(t *testing.T, n int) string {
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(b)
}
// requireMCPServerConfigDecrypted verifies all encrypted fields on an
// MCPServerConfig match the expected plaintext values and carry the
// correct key-ID.
func requireMCPServerConfigDecrypted(
t *testing.T,
cfg database.MCPServerConfig,
ciphers []Cipher,
wantSecret, wantAPIKey, wantHeaders string,
) {
t.Helper()
require.Equal(t, wantSecret, cfg.OAuth2ClientSecret)
require.Equal(t, wantAPIKey, cfg.APIKeyValue)
require.Equal(t, wantHeaders, cfg.CustomHeaders)
require.Equal(t, ciphers[0].HexDigest(), cfg.OAuth2ClientSecretKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), cfg.APIKeyValueKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), cfg.CustomHeadersKeyID.String)
}
// requireMCPServerConfigRawEncrypted reads the config from the raw
// (unwrapped) store and asserts every secret field is encrypted.
func requireMCPServerConfigRawEncrypted(
ctx context.Context,
t *testing.T,
rawDB database.Store,
cfgID uuid.UUID,
ciphers []Cipher,
wantSecret, wantAPIKey, wantHeaders string,
) {
t.Helper()
raw, err := rawDB.GetMCPServerConfigByID(ctx, cfgID)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], raw.OAuth2ClientSecret, wantSecret)
requireEncryptedEquals(t, ciphers[0], raw.APIKeyValue, wantAPIKey)
requireEncryptedEquals(t, ciphers[0], raw.CustomHeaders, wantHeaders)
}
func TestMCPServerConfigs(t *testing.T) {
t.Parallel()
ctx := context.Background()
const (
//nolint:gosec // test credentials
oauthSecret = "my-oauth-secret"
apiKeyValue = "my-api-key"
customHeaders = `{"X-Custom":"header-value"}`
)
// insertConfig is a small helper that creates a user and an MCP
// server config through the encrypted store, returning both.
insertConfig := func(t *testing.T, crypt *dbCrypt, ciphers []Cipher) database.MCPServerConfig {
t.Helper()
user := dbgen.User(t, crypt, database.User{})
cfg, err := crypt.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
DisplayName: "Test MCP Server",
Slug: "test-mcp-" + uuid.New().String()[:8],
Description: "test description",
Url: "https://mcp.example.com",
Transport: "streamable_http",
AuthType: "oauth2",
OAuth2ClientID: "client-id",
OAuth2ClientSecret: oauthSecret,
APIKeyValue: apiKeyValue,
CustomHeaders: customHeaders,
ToolAllowList: []string{},
ToolDenyList: []string{},
Availability: "force_on",
Enabled: true,
CreatedBy: user.ID,
UpdatedBy: user.ID,
})
require.NoError(t, err)
requireMCPServerConfigDecrypted(t, cfg, ciphers, oauthSecret, apiKeyValue, customHeaders)
return cfg
}
t.Run("InsertMCPServerConfig", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("GetMCPServerConfigByID", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
got, err := crypt.GetMCPServerConfigByID(ctx, cfg.ID)
require.NoError(t, err)
requireMCPServerConfigDecrypted(t, got, ciphers, oauthSecret, apiKeyValue, customHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("GetMCPServerConfigBySlug", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
got, err := crypt.GetMCPServerConfigBySlug(ctx, cfg.Slug)
require.NoError(t, err)
requireMCPServerConfigDecrypted(t, got, ciphers, oauthSecret, apiKeyValue, customHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("GetMCPServerConfigs", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
cfgs, err := crypt.GetMCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, cfgs, 1)
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("GetMCPServerConfigsByIDs", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
cfgs, err := crypt.GetMCPServerConfigsByIDs(ctx, []uuid.UUID{cfg.ID})
require.NoError(t, err)
require.Len(t, cfgs, 1)
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("GetEnabledMCPServerConfigs", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
cfgs, err := crypt.GetEnabledMCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, cfgs, 1)
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("GetForcedMCPServerConfigs", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
cfgs, err := crypt.GetForcedMCPServerConfigs(ctx)
require.NoError(t, err)
require.Len(t, cfgs, 1)
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
})
t.Run("UpdateMCPServerConfig", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg := insertConfig(t, crypt, ciphers)
const (
//nolint:gosec // test credential
newSecret = "updated-oauth-secret"
newAPIKey = "updated-api-key"
newHeaders = `{"X-New":"new-value"}`
)
updated, err := crypt.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{
ID: cfg.ID,
DisplayName: cfg.DisplayName,
Slug: cfg.Slug,
Description: cfg.Description,
Url: cfg.Url,
Transport: cfg.Transport,
AuthType: cfg.AuthType,
OAuth2ClientID: cfg.OAuth2ClientID,
OAuth2ClientSecret: newSecret,
APIKeyValue: newAPIKey,
CustomHeaders: newHeaders,
ToolAllowList: cfg.ToolAllowList,
ToolDenyList: cfg.ToolDenyList,
Availability: cfg.Availability,
Enabled: cfg.Enabled,
UpdatedBy: cfg.CreatedBy.UUID,
})
require.NoError(t, err)
requireMCPServerConfigDecrypted(t, updated, ciphers, newSecret, newAPIKey, newHeaders)
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, newSecret, newAPIKey, newHeaders)
})
}
func TestMCPServerUserTokens(t *testing.T) {
t.Parallel()
ctx := context.Background()
const (
accessToken = "access-token-value"
refreshToken = "refresh-token-value"
)
// insertConfigAndToken creates a user, an MCP server config, and a
// user token through the encrypted store.
insertConfigAndToken := func(
t *testing.T,
crypt *dbCrypt,
ciphers []Cipher,
) (database.MCPServerConfig, database.MCPServerUserToken) {
t.Helper()
user := dbgen.User(t, crypt, database.User{})
cfg, err := crypt.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
DisplayName: "Token Test MCP",
Slug: "tok-mcp-" + uuid.New().String()[:8],
Url: "https://mcp.example.com",
Transport: "streamable_http",
AuthType: "oauth2",
ToolAllowList: []string{},
ToolDenyList: []string{},
Availability: "default_off",
Enabled: true,
CreatedBy: user.ID,
UpdatedBy: user.ID,
})
require.NoError(t, err)
tok, err := crypt.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
MCPServerConfigID: cfg.ID,
UserID: user.ID,
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
})
require.NoError(t, err)
require.Equal(t, accessToken, tok.AccessToken)
require.Equal(t, refreshToken, tok.RefreshToken)
require.Equal(t, ciphers[0].HexDigest(), tok.AccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), tok.RefreshTokenKeyID.String)
return cfg, tok
}
t.Run("UpsertMCPServerUserToken", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg, tok := insertConfigAndToken(t, crypt, ciphers)
// Verify the raw DB values are encrypted.
rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
MCPServerConfigID: cfg.ID,
UserID: tok.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken)
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
})
t.Run("GetMCPServerUserToken", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg, tok := insertConfigAndToken(t, crypt, ciphers)
got, err := crypt.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
MCPServerConfigID: cfg.ID,
UserID: tok.UserID,
})
require.NoError(t, err)
require.Equal(t, accessToken, got.AccessToken)
require.Equal(t, refreshToken, got.RefreshToken)
require.Equal(t, ciphers[0].HexDigest(), got.AccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), got.RefreshTokenKeyID.String)
// Raw values must be encrypted.
rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
MCPServerConfigID: cfg.ID,
UserID: tok.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken)
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
})
t.Run("GetMCPServerUserTokensByUserID", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg, tok := insertConfigAndToken(t, crypt, ciphers)
toks, err := crypt.GetMCPServerUserTokensByUserID(ctx, tok.UserID)
require.NoError(t, err)
require.Len(t, toks, 1)
require.Equal(t, accessToken, toks[0].AccessToken)
require.Equal(t, refreshToken, toks[0].RefreshToken)
require.Equal(t, ciphers[0].HexDigest(), toks[0].AccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), toks[0].RefreshTokenKeyID.String)
// Raw values must be encrypted.
rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
MCPServerConfigID: cfg.ID,
UserID: tok.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken)
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
})
}
+1
View File
@@ -73,6 +73,7 @@ const makeChat = (
id,
owner_id: "owner-1",
last_model_config_id: "model-1",
mcp_server_ids: [],
title: `Chat ${id}`,
status: "running",
created_at: "2025-01-01T00:00:00.000Z",
+100
View File
@@ -1070,6 +1070,7 @@ export interface Chat {
readonly created_at: string;
readonly updated_at: string;
readonly archived: boolean;
readonly mcp_server_ids: readonly string[];
}
// From codersdk/deployment.go
@@ -2076,6 +2077,7 @@ export interface ConvertLoginRequest {
export interface CreateChatMessageRequest {
readonly content: readonly ChatInputPart[];
readonly model_config_id?: string;
readonly mcp_server_ids?: string[];
}
// From codersdk/chats.go
@@ -2123,6 +2125,7 @@ export interface CreateChatRequest {
readonly content: readonly ChatInputPart[];
readonly workspace_id?: string;
readonly model_config_id?: string;
readonly mcp_server_ids?: readonly string[];
}
// From codersdk/users.go
@@ -2163,6 +2166,32 @@ export interface CreateGroupRequest {
readonly quota_allowance: number;
}
// From codersdk/mcp.go
/**
* CreateMCPServerConfigRequest is the request to create a new MCP server config.
*/
export interface CreateMCPServerConfigRequest {
readonly display_name: string;
readonly slug: string;
readonly description: string;
readonly icon_url: string;
readonly transport: string;
readonly url: string;
readonly auth_type: string;
readonly oauth2_client_id?: string;
readonly oauth2_client_secret?: string;
readonly oauth2_auth_url?: string;
readonly oauth2_token_url?: string;
readonly oauth2_scopes?: string;
readonly api_key_header?: string;
readonly api_key_value?: string;
readonly custom_headers?: Record<string, string>;
readonly tool_allow_list?: readonly string[];
readonly tool_deny_list?: readonly string[];
readonly availability: string;
readonly enabled: boolean;
}
// From codersdk/organizations.go
export interface CreateOrganizationRequest {
readonly name: string;
@@ -3692,6 +3721,51 @@ export interface LoginWithPasswordResponse {
readonly session_token: string;
}
// From codersdk/mcp.go
/**
* MCPServerConfig represents an admin-configured MCP server.
*/
export interface MCPServerConfig {
readonly id: string;
readonly display_name: string;
readonly slug: string;
readonly description: string;
readonly icon_url: string;
readonly transport: string; // "streamable_http" or "sse"
readonly url: string;
readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers"
/**
* OAuth2 fields (only populated for admins).
*/
readonly oauth2_client_id?: string;
readonly has_oauth2_secret: boolean;
readonly oauth2_auth_url?: string;
readonly oauth2_token_url?: string;
readonly oauth2_scopes?: string;
/**
* API key fields (only populated for admins).
*/
readonly api_key_header?: string;
readonly has_api_key: boolean;
readonly has_custom_headers: boolean;
/**
* Tool governance.
*/
readonly tool_allow_list: readonly string[];
readonly tool_deny_list: readonly string[];
/**
* Availability policy set by admin.
*/
readonly availability: string; // "force_on", "default_on", "default_off"
readonly enabled: boolean;
readonly created_at: string;
readonly updated_at: string;
/**
* Per-user state (populated for non-admin requests).
*/
readonly auth_connected: boolean;
}
// From codersdk/provisionerdaemons.go
/**
* MatchedProvisioners represents the number of provisioner daemons
@@ -6844,6 +6918,32 @@ export interface UpdateInboxNotificationReadStatusResponse {
readonly unread_count: number;
}
// From codersdk/mcp.go
/**
* UpdateMCPServerConfigRequest is the request to update an MCP server config.
*/
export interface UpdateMCPServerConfigRequest {
readonly display_name?: string;
readonly slug?: string;
readonly description?: string;
readonly icon_url?: string;
readonly transport?: string;
readonly url?: string;
readonly auth_type?: string;
readonly oauth2_client_id?: string;
readonly oauth2_client_secret?: string;
readonly oauth2_auth_url?: string;
readonly oauth2_token_url?: string;
readonly oauth2_scopes?: string;
readonly api_key_header?: string;
readonly api_key_value?: string;
readonly custom_headers?: Record<string, string>;
readonly tool_allow_list?: string[];
readonly tool_deny_list?: string[];
readonly availability?: string;
readonly enabled?: boolean;
}
// From codersdk/notifications.go
export interface UpdateNotificationTemplateMethod {
readonly method?: string;
@@ -106,6 +106,7 @@ const baseChatFields = {
owner_id: "owner-id",
workspace_id: mockWorkspace.id,
last_model_config_id: "model-config-1",
mcp_server_ids: [],
created_at: "2026-02-18T00:00:00.000Z",
updated_at: "2026-02-18T00:00:00.000Z",
archived: false,
@@ -183,6 +183,7 @@ const makeChat = (chatID: string): TypesGen.Chat => ({
id: chatID,
owner_id: "owner-1",
last_model_config_id: "model-1",
mcp_server_ids: [],
title: "test",
status: "running",
created_at: "2025-01-01T00:00:00.000Z",
@@ -52,6 +52,7 @@ export const WithParentChat: Story = {
id: "parent-chat-1",
owner_id: "owner-id",
last_model_config_id: "model-config-1",
mcp_server_ids: [],
title: "Set up CI/CD pipeline",
status: "completed",
last_error: null,
@@ -36,6 +36,7 @@ const buildChat = (overrides: Partial<TypesGen.Chat> = {}): TypesGen.Chat => ({
title: "Help me refactor",
status: "completed",
last_model_config_id: "model-config-1",
mcp_server_ids: [],
created_at: oneWeekAgo,
updated_at: oneWeekAgo,
archived: false,
@@ -113,6 +113,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
title: "Agent",
status: "completed",
last_model_config_id: defaultModelConfigs[0].id,
mcp_server_ids: [],
created_at: oneWeekAgo,
updated_at: oneWeekAgo,
archived: false,
@@ -40,6 +40,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
title: "Agent",
status: "completed",
last_model_config_id: defaultModelConfigs[0].id,
mcp_server_ids: [],
created_at: oneWeekAgo,
updated_at: oneWeekAgo,
archived: false,
@@ -63,6 +63,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
updated_at: oneWeekAgo,
archived: false,
last_error: null,
mcp_server_ids: [],
...overrides,
});