From d8ff67fb689c5b12412f5d73bf3e27441b3b37f2 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 19 Mar 2026 10:07:36 -0400 Subject: [PATCH] feat: add MCP server configuration backend for chats (#23227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Adds the database schema, API endpoints, SDK types, and encryption wrappers for admin-managed MCP (Model Context Protocol) server configurations that chatd can consume. This is the backend foundation for allowing external MCP tools (Sentry, Linear, GitHub, etc.) to be used during AI chat sessions. ## Database Two new tables: - **`mcp_server_configs`**: Admin-managed server definitions with URL, transport (Streamable HTTP / SSE), auth config (none / OAuth2 / API key / custom headers), tool allow/deny lists, and an availability policy (`force_on` / `default_on` / `default_off`). Includes CHECK constraints on transport, auth_type, and availability values. - **`mcp_server_user_tokens`**: Per-user OAuth2 tokens for servers requiring individual authentication. Cascades on user/config deletion. New column on `chats` table: - **`mcp_server_ids UUID[]`**: Per-chat MCP server selection, following the same pattern as `model_config_id` — passed at chat creation, changeable per-message with nil-means-no-change semantics. ## API Endpoints All routes are under `/api/experimental/mcp/servers/` and gated behind the `agents` experiment. **Admin endpoints** (`ResourceDeploymentConfig` auth): - `POST /` — Create MCP server config - `PATCH /{id}` — Update MCP server config (full-replace) - `DELETE /{id}` — Delete MCP server config **Authenticated endpoints** (all users, enabled servers only for non-admins): - `GET /` — List configs (admins see all, members see enabled-only with admin fields redacted) - `GET /{id}` — Get config by ID (with `auth_connected` populated per-user) **OAuth2 per-user auth flow:** - `GET /{id}/oauth2/connect` — Initiate OAuth2 flow (state cookie CSRF protection) - `GET /{id}/oauth2/callback` — Handle OAuth2 callback, store tokens - `DELETE /{id}/oauth2/disconnect` — Remove stored OAuth2 tokens ## Security - **Secrets never returned**: `OAuth2ClientSecret`, `APIKeyValue`, and `CustomHeaders` are never in API responses — only boolean indicators (`has_oauth2_secret`, `has_api_key`, `has_custom_headers`). - **Field redaction for non-admins**: `convertMCPServerConfigRedacted` strips `OAuth2ClientID`, auth URLs, scopes, and `APIKeyHeader` from non-admin responses. - **dbcrypt encryption at rest**: All 5 secret fields use `dbcrypt_keys` encryption with full encrypt-on-write / decrypt-on-read wrappers (11 dbcrypt method overrides + 2 helpers), following the same pattern as `chat_providers.api_key`. - **OAuth2 CSRF protection**: State parameter stored in `HttpOnly` cookie with `HTTPCookies.Apply()` for correct `Secure`/`SameSite` behind TLS-terminating proxies. - **dbauthz authorization**: All 18 querier methods have authorization wrappers. Read operations use `ActionRead`, write operations use `ActionUpdate` on `ResourceDeploymentConfig`. ## Governance Model | Control | Implementation | |---------|---------------| | **Global kill switch** | `enabled` defaults to `false` | | **Availability policy** | `force_on` (always injected), `default_on` (pre-selected), `default_off` (opt-in) | | **Per-chat selection** | `mcp_server_ids` on `CreateChatRequest` / `CreateChatMessageRequest` | | **Auth gate** | OAuth2 servers require per-user auth before tools are injected | | **Tool-level allow/deny** | Arrays on `mcp_server_configs` for granular tool filtering | | **Secrets encrypted at rest** | Uses `dbcrypt_keys` (same pattern as `chat_providers.api_key`) | ## Tests 8 test functions covering: - Full CRUD lifecycle (create, list, update, delete) - Non-admin visibility filtering (enabled-only, field redaction) - `auth_connected` population for OAuth2 vs non-OAuth2 servers - Availability policy validation (valid values + invalid rejection) - Unique slug enforcement (409 Conflict) - OAuth2 disconnect idempotency - Chat creation with `mcp_server_ids` persistence ## Known Limitations (Deferred) These are documented and intentional for an experimental feature: - **Audit logging** not yet wired — will add when feature stabilizes - **Cross-field validation** (e.g., OAuth2 fields required when `auth_type=oauth2`) — admin-only endpoint, will add when stabilizing - **`force_on` auto-injection** — query exists but not yet wired into chatd tool injection (follow-up) - **Additional test coverage** — 403 auth tests, GET-by-ID tests, callback CSRF tests planned for follow-up ## What's NOT in this PR - Frontend UI (admin panel + chat picker) - Actual MCP client connections (`chatd/chatmcp/` manager) - Tool injection into `chatloop/` --- coderd/chatd/chatd.go | 20 + coderd/chats.go | 72 ++ coderd/coderd.go | 21 +- coderd/database/check_constraint.go | 3 + coderd/database/dbauthz/dbauthz.go | 109 +++ coderd/database/dbauthz/dbauthz_test.go | 112 ++- coderd/database/dbmetrics/querymetrics.go | 120 +++ coderd/database/dbmock/dbmock.go | 222 +++++ coderd/database/dump.sql | 95 +- coderd/database/foreign_key_constraint.go | 9 + .../000447_mcp_server_configs.down.sql | 6 + .../000447_mcp_server_configs.up.sql | 75 ++ .../fixtures/000447_mcp_server_configs.up.sql | 48 + coderd/database/modelqueries.go | 1 + coderd/database/models.go | 45 + coderd/database/querier.go | 15 + coderd/database/querier_test.go | 10 +- coderd/database/queries.sql.go | 879 ++++++++++++++++- coderd/database/queries/chats.sql | 17 +- coderd/database/queries/mcpserverconfigs.sql | 213 ++++ coderd/database/sqlc.yaml | 22 + coderd/database/unique_constraint.go | 4 + coderd/mcp.go | 921 ++++++++++++++++++ coderd/mcp_test.go | 489 ++++++++++ coderd/workspaceproxies.go | 2 +- codersdk/chats.go | 3 + codersdk/mcp.go | 191 ++++ enterprise/coderd/workspaceproxy.go | 4 +- enterprise/dbcrypt/dbcrypt.go | 195 ++++ enterprise/dbcrypt/dbcrypt_internal_test.go | 299 ++++++ site/src/api/queries/chats.test.ts | 1 + site/src/api/typesGenerated.ts | 100 ++ .../pages/AgentsPage/AgentDetail.stories.tsx | 1 + .../AgentDetail/ChatContext.test.tsx | 1 + .../AgentsPage/AgentDetail/TopBar.stories.tsx | 1 + .../AgentsPage/AgentDetailView.stories.tsx | 1 + .../AgentsPage/AgentsPageView.stories.tsx | 1 + .../AgentsPage/AgentsSidebar.stories.tsx | 1 + .../pages/AgentsPage/AgentsSidebar.test.tsx | 1 + 39 files changed, 4300 insertions(+), 30 deletions(-) create mode 100644 coderd/database/migrations/000447_mcp_server_configs.down.sql create mode 100644 coderd/database/migrations/000447_mcp_server_configs.up.sql create mode 100644 coderd/database/migrations/testdata/fixtures/000447_mcp_server_configs.up.sql create mode 100644 coderd/database/queries/mcpserverconfigs.sql create mode 100644 coderd/mcp.go create mode 100644 coderd/mcp_test.go create mode 100644 codersdk/mcp.go diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index 358ee756cb..98ce8e6b1c 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -379,6 +379,7 @@ type CreateOptions struct { ChatMode database.NullChatMode SystemPrompt string InitialUserContent []codersdk.ChatMessagePart + MCPServerIDs []uuid.UUID } // SendMessageBusyBehavior controls what happens when a chat is already active. @@ -401,6 +402,7 @@ type SendMessageOptions struct { Content []codersdk.ChatMessagePart ModelConfigID *uuid.UUID BusyBehavior SendMessageBusyBehavior + MCPServerIDs *[]uuid.UUID } // SendMessageResult contains the outcome of user message processing. @@ -450,6 +452,12 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C if len(opts.InitialUserContent) == 0 { return database.Chat{}, xerrors.New("initial user content is required") } + // Ensure MCPServerIDs is non-nil so pq.Array produces '{}' + // instead of SQL NULL, which violates the NOT NULL column + // constraint. + if opts.MCPServerIDs == nil { + opts.MCPServerIDs = []uuid.UUID{} + } var chat database.Chat txErr := p.db.InTx(func(tx database.Store) error { @@ -465,6 +473,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C LastModelConfigID: opts.ModelConfigID, Title: opts.Title, Mode: opts.ChatMode, + MCPServerIDs: opts.MCPServerIDs, }) if err != nil { return xerrors.Errorf("insert chat: %w", err) @@ -596,6 +605,17 @@ func (p *Server) SendMessage( modelConfigID = *opts.ModelConfigID } + // Update MCP server IDs on the chat when explicitly provided. + if opts.MCPServerIDs != nil { + lockedChat, err = tx.UpdateChatMCPServerIDs(ctx, database.UpdateChatMCPServerIDsParams{ + ID: opts.ChatID, + MCPServerIDs: *opts.MCPServerIDs, + }) + if err != nil { + return xerrors.Errorf("update chat mcp server ids: %w", err) + } + } + existingQueued, err := tx.GetChatQueuedMessages(ctx, opts.ChatID) if err != nil { return xerrors.Errorf("get queued messages: %w", err) diff --git a/coderd/chats.go b/coderd/chats.go index 3041c0b8ce..ac3d33d203 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -284,6 +284,41 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } + // Validate MCP server IDs exist. + if len(req.MCPServerIDs) > 0 { + //nolint:gocritic // Need to validate MCP server IDs exist. + existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), req.MCPServerIDs) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to validate MCP server IDs.", + Detail: err.Error(), + }) + return + } + if len(existingConfigs) != len(req.MCPServerIDs) { + found := make(map[uuid.UUID]struct{}, len(existingConfigs)) + for _, c := range existingConfigs { + found[c.ID] = struct{}{} + } + var missing []string + for _, id := range req.MCPServerIDs { + if _, ok := found[id]; !ok { + missing = append(missing, id.String()) + } + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "One or more MCP server IDs are invalid.", + Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")), + }) + return + } + } + + mcpServerIDs := req.MCPServerIDs + if mcpServerIDs == nil { + mcpServerIDs = []uuid.UUID{} + } + chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{ OwnerID: apiKey.UserID, WorkspaceID: workspaceSelection.WorkspaceID, @@ -291,6 +326,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { ModelConfigID: modelConfigID, SystemPrompt: api.resolvedChatSystemPrompt(ctx), InitialUserContent: contentBlocks, + MCPServerIDs: mcpServerIDs, }) if err != nil { if maybeWriteLimitErr(ctx, rw, err) { @@ -1456,6 +1492,36 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } + // Validate MCP server IDs exist. + if req.MCPServerIDs != nil && len(*req.MCPServerIDs) > 0 { + //nolint:gocritic // Need to validate MCP server IDs exist. + existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), *req.MCPServerIDs) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to validate MCP server IDs.", + Detail: err.Error(), + }) + return + } + if len(existingConfigs) != len(*req.MCPServerIDs) { + found := make(map[uuid.UUID]struct{}, len(existingConfigs)) + for _, c := range existingConfigs { + found[c.ID] = struct{}{} + } + var missing []string + for _, id := range *req.MCPServerIDs { + if _, ok := found[id]; !ok { + missing = append(missing, id.String()) + } + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "One or more MCP server IDs are invalid.", + Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")), + }) + return + } + } + sendResult, sendErr := api.chatDaemon.SendMessage( ctx, chatd.SendMessageOptions{ @@ -1464,6 +1530,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { Content: contentBlocks, ModelConfigID: req.ModelConfigID, BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + MCPServerIDs: req.MCPServerIDs, }, ) if sendErr != nil { @@ -2979,6 +3046,10 @@ func truncateRunes(value string, maxLen int) string { } func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat { + mcpServerIDs := c.MCPServerIDs + if mcpServerIDs == nil { + mcpServerIDs = []uuid.UUID{} + } chat := codersdk.Chat{ ID: c.ID, OwnerID: c.OwnerID, @@ -2988,6 +3059,7 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk. Archived: c.Archived, CreatedAt: c.CreatedAt, UpdatedAt: c.UpdatedAt, + MCPServerIDs: mcpServerIDs, } if c.LastError.Valid { chat.LastError = &c.LastError.String diff --git a/coderd/coderd.go b/coderd/coderd.go index fc9f298cf7..84144614e7 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1233,10 +1233,27 @@ func New(options *Options) *API { r.Route("/mcp", func(r chi.Router) { r.Use( apiKeyMiddleware, - httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP), ) + // MCP server configuration endpoints. + r.Route("/servers", func(r chi.Router) { + r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents)) + r.Get("/", api.listMCPServerConfigs) + r.Post("/", api.createMCPServerConfig) + r.Route("/{mcpServer}", func(r chi.Router) { + r.Get("/", api.getMCPServerConfig) + r.Patch("/", api.updateMCPServerConfig) + r.Delete("/", api.deleteMCPServerConfig) + // OAuth2 user flow + r.Get("/oauth2/connect", api.mcpServerOAuth2Connect) + r.Get("/oauth2/callback", api.mcpServerOAuth2Callback) + r.Delete("/oauth2/disconnect", api.mcpServerOAuth2Disconnect) + }) + }) // MCP HTTP transport endpoint with mandatory authentication - r.Mount("/http", api.mcpHTTPHandler()) + r.Route("/http", func(r chi.Router) { + r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP)) + r.Mount("/", api.mcpHTTPHandler()) + }) }) r.Route("/watch-all-workspacebuilds", func(r chi.Router) { r.Use( diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index af6d0fc248..51f8326779 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -20,6 +20,9 @@ const ( CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users + CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs + CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs + CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 931a8cabfe..5a2a5981cb 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1691,6 +1691,13 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error { return q.db.CleanTailnetTunnels(ctx) } +func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil { + return err + } + return q.db.CleanupDeletedMCPServerIDsFromChats(ctx) +} + func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) if err != nil { @@ -1920,6 +1927,20 @@ func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) { return id, nil } +func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.DeleteMCPServerConfigByID(ctx, id) +} + +func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.DeleteMCPServerUserToken(ctx, arg) +} + func (q *querier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error { if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceOauth2App); err != nil { return err @@ -2763,6 +2784,13 @@ func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatP return q.db.GetEnabledChatProviders(ctx) } +func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return nil, err + } + return q.db.GetEnabledMCPServerConfigs(ctx) +} + func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) { return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg) } @@ -2821,6 +2849,13 @@ func (q *querier) GetFilteredInboxNotificationsByUserID(ctx context.Context, arg return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetFilteredInboxNotificationsByUserID)(ctx, arg) } +func (q *querier) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return nil, err + } + return q.db.GetForcedMCPServerConfigs(ctx) +} + func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID) } @@ -2961,6 +2996,48 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) { return q.db.GetLogoURL(ctx) } +func (q *querier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return database.MCPServerConfig{}, err + } + return q.db.GetMCPServerConfigByID(ctx, id) +} + +func (q *querier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return database.MCPServerConfig{}, err + } + return q.db.GetMCPServerConfigBySlug(ctx, slug) +} + +func (q *querier) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return nil, err + } + return q.db.GetMCPServerConfigs(ctx) +} + +func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return nil, err + } + return q.db.GetMCPServerConfigsByIDs(ctx, ids) +} + +func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return database.MCPServerUserToken{}, err + } + return q.db.GetMCPServerUserToken(ctx, arg) +} + +func (q *querier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return nil, err + } + return q.db.GetMCPServerUserTokensByUserID(ctx, userID) +} + func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil { return nil, err @@ -4727,6 +4804,13 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP return q.db.InsertLicense(ctx, arg) } +func (q *querier) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return database.MCPServerConfig{}, err + } + return q.db.InsertMCPServerConfig(ctx, arg) +} + func (q *querier) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) { if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil { return database.WorkspaceAgentMemoryResourceMonitor{}, err @@ -5486,6 +5570,17 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh return q.db.UpdateChatHeartbeat(ctx, arg) } +func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + return q.db.UpdateChatMCPServerIDs(ctx, arg) +} + func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) { // Authorize update on the parent chat of the edited message. msg, err := q.db.GetChatMessageByID(ctx, arg.ID) @@ -5650,6 +5745,13 @@ func (q *querier) UpdateInboxNotificationReadStatus(ctx context.Context, args da return update(q.log, q.auth, fetchFunc, q.db.UpdateInboxNotificationReadStatus)(ctx, args) } +func (q *querier) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return database.MCPServerConfig{}, err + } + return q.db.UpdateMCPServerConfig(ctx, arg) +} + func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { // Authorized fetch will check that the actor has read access to the org member since the org member is returned. member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{ @@ -6695,6 +6797,13 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error { return q.db.UpsertLogoURL(ctx, value) } +func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return database.MCPServerUserToken{}, err + } + return q.db.UpsertMCPServerUserToken(ctx, arg) +} + func (q *querier) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error { if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index df72ccc98b..cc76fe7c8d 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1005,6 +1005,114 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes() check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) + s.Run("CleanupDeletedMCPServerIDsFromChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().CleanupDeletedMCPServerIDsFromChats(gomock.Any()).Return(nil).AnyTimes() + check.Args().Asserts(rbac.ResourceChat, policy.ActionUpdate) + })) + s.Run("DeleteMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().DeleteMCPServerConfigByID(gomock.Any(), id).Return(nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) + s.Run("DeleteMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.DeleteMCPServerUserTokenParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + } + dbm.EXPECT().DeleteMCPServerUserToken(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) + s.Run("GetEnabledMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + dbm.EXPECT().GetEnabledMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB}) + })) + s.Run("GetForcedMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + dbm.EXPECT().GetForcedMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB}) + })) + s.Run("GetMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + config := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + dbm.EXPECT().GetMCPServerConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes() + check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config) + })) + s.Run("GetMCPServerConfigBySlug", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + slug := "test-mcp-server" + config := testutil.Fake(s.T(), faker, database.MCPServerConfig{Slug: slug}) + dbm.EXPECT().GetMCPServerConfigBySlug(gomock.Any(), slug).Return(config, nil).AnyTimes() + check.Args(slug).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config) + })) + s.Run("GetMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + dbm.EXPECT().GetMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB}) + })) + s.Run("GetMCPServerConfigsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + ids := []uuid.UUID{configA.ID, configB.ID} + dbm.EXPECT().GetMCPServerConfigsByIDs(gomock.Any(), ids).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB}) + })) + s.Run("GetMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.GetMCPServerUserTokenParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + } + token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().GetMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(token) + })) + s.Run("GetMCPServerUserTokensByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + userID := uuid.New() + tokens := []database.MCPServerUserToken{testutil.Fake(s.T(), faker, database.MCPServerUserToken{UserID: userID})} + dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes() + check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens) + })) + s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.InsertMCPServerConfigParams{ + DisplayName: "Test MCP Server", + Slug: "test-mcp-server", + } + config := testutil.Fake(s.T(), faker, database.MCPServerConfig{DisplayName: arg.DisplayName, Slug: arg.Slug}) + dbm.EXPECT().InsertMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config) + })) + s.Run("UpdateChatMCPServerIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatMCPServerIDsParams{ + ID: chat.ID, + MCPServerIDs: []uuid.UUID{uuid.New()}, + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) + })) + s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + config := testutil.Fake(s.T(), faker, database.MCPServerConfig{}) + arg := database.UpdateMCPServerConfigParams{ + ID: config.ID, + DisplayName: "Updated MCP Server", + Slug: "updated-mcp-server", + } + dbm.EXPECT().UpdateMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config) + })) + s.Run("UpsertMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.UpsertMCPServerUserTokenParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + AccessToken: "test-access-token", + TokenType: "bearer", + } + token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().UpsertMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(token) + })) } func (s *MethodTestSuite) TestFile() { @@ -1381,8 +1489,8 @@ func (s *MethodTestSuite) TestLicense() { check.Args().Asserts().Returns("value") })) s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { - dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes() - check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}) + dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}, nil).AnyTimes() + check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}) })) s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 1d30d1f3ec..40ad46275a 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -264,6 +264,14 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error { return r0 } +func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error { + start := time.Now() + r0 := m.s.CleanupDeletedMCPServerIDsFromChats(ctx) + m.queryLatencies.WithLabelValues("CleanupDeletedMCPServerIDsFromChats").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CleanupDeletedMCPServerIDsFromChats").Inc() + return r0 +} + func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { start := time.Now() r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg) @@ -480,6 +488,22 @@ func (m queryMetricsStore) DeleteLicense(ctx context.Context, id int32) (int32, return r0, r1 } +func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error { + start := time.Now() + r0 := m.s.DeleteMCPServerConfigByID(ctx, id) + m.queryLatencies.WithLabelValues("DeleteMCPServerConfigByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerConfigByID").Inc() + return r0 +} + +func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { + start := time.Now() + r0 := m.s.DeleteMCPServerUserToken(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteMCPServerUserToken").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserToken").Inc() + return r0 +} + func (m queryMetricsStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error { start := time.Now() r0 := m.s.DeleteOAuth2ProviderAppByClientID(ctx, id) @@ -1328,6 +1352,14 @@ func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]datab return r0, r1 } +func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx) + m.queryLatencies.WithLabelValues("GetEnabledMCPServerConfigs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledMCPServerConfigs").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) { start := time.Now() r0, r1 := m.s.GetExternalAuthLink(ctx, arg) @@ -1384,6 +1416,14 @@ func (m queryMetricsStore) GetFilteredInboxNotificationsByUserID(ctx context.Con return r0, r1 } +func (m queryMetricsStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.GetForcedMCPServerConfigs(ctx) + m.queryLatencies.WithLabelValues("GetForcedMCPServerConfigs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetForcedMCPServerConfigs").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { start := time.Now() r0, r1 := m.s.GetGitSSHKey(ctx, userID) @@ -1544,6 +1584,54 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) { return r0, r1 } +func (m queryMetricsStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerConfigByID(ctx, id) + m.queryLatencies.WithLabelValues("GetMCPServerConfigByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigByID").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerConfigBySlug(ctx, slug) + m.queryLatencies.WithLabelValues("GetMCPServerConfigBySlug").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigBySlug").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerConfigs(ctx) + m.queryLatencies.WithLabelValues("GetMCPServerConfigs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigs").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerConfigsByIDs(ctx, ids) + m.queryLatencies.WithLabelValues("GetMCPServerConfigsByIDs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigsByIDs").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerUserToken(ctx, arg) + m.queryLatencies.WithLabelValues("GetMCPServerUserToken").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserToken").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerUserTokensByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetMCPServerUserTokensByUserID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserTokensByUserID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { start := time.Now() r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg) @@ -3184,6 +3272,14 @@ func (m queryMetricsStore) InsertLicense(ctx context.Context, arg database.Inser return r0, r1 } +func (m queryMetricsStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.InsertMCPServerConfig(ctx, arg) + m.queryLatencies.WithLabelValues("InsertMCPServerConfig").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertMCPServerConfig").Inc() + return r0, r1 +} + func (m queryMetricsStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) { start := time.Now() r0, r1 := m.s.InsertMemoryResourceMonitor(ctx, arg) @@ -3856,6 +3952,14 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database return r0, r1 } +func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatMCPServerIDs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMCPServerIDs").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.UpdateChatMessageByID(ctx, arg) @@ -3960,6 +4064,14 @@ func (m queryMetricsStore) UpdateInboxNotificationReadStatus(ctx context.Context return r0 } +func (m queryMetricsStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) { + start := time.Now() + r0, r1 := m.s.UpdateMCPServerConfig(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateMCPServerConfig").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateMCPServerConfig").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { start := time.Now() r0, r1 := m.s.UpdateMemberRoles(ctx, arg) @@ -4696,6 +4808,14 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro return r0 } +func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + start := time.Now() + r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertMCPServerUserToken").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserToken").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error { start := time.Now() r0 := m.s.UpsertNotificationReportGeneratorLog(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 4b3831307c..0df5aed2f1 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -334,6 +334,20 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx) } +// CleanupDeletedMCPServerIDsFromChats mocks base method. +func (m *MockStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupDeletedMCPServerIDsFromChats", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupDeletedMCPServerIDsFromChats indicates an expected call of CleanupDeletedMCPServerIDsFromChats. +func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx) +} + // CountAIBridgeInterceptions mocks base method. func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { m.ctrl.T.Helper() @@ -769,6 +783,34 @@ func (mr *MockStoreMockRecorder) DeleteLicense(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLicense", reflect.TypeOf((*MockStore)(nil).DeleteLicense), ctx, id) } +// DeleteMCPServerConfigByID mocks base method. +func (m *MockStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteMCPServerConfigByID", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteMCPServerConfigByID indicates an expected call of DeleteMCPServerConfigByID. +func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id) +} + +// DeleteMCPServerUserToken mocks base method. +func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteMCPServerUserToken", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteMCPServerUserToken indicates an expected call of DeleteMCPServerUserToken. +func (mr *MockStoreMockRecorder) DeleteMCPServerUserToken(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserToken), ctx, arg) +} + // DeleteOAuth2ProviderAppByClientID mocks base method. func (m *MockStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error { m.ctrl.T.Helper() @@ -2437,6 +2479,21 @@ func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx) } +// GetEnabledMCPServerConfigs mocks base method. +func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnabledMCPServerConfigs", ctx) + ret0, _ := ret[0].([]database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnabledMCPServerConfigs indicates an expected call of GetEnabledMCPServerConfigs. +func (mr *MockStoreMockRecorder) GetEnabledMCPServerConfigs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledMCPServerConfigs), ctx) +} + // GetExternalAuthLink mocks base method. func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) { m.ctrl.T.Helper() @@ -2542,6 +2599,21 @@ func (mr *MockStoreMockRecorder) GetFilteredInboxNotificationsByUserID(ctx, arg return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).GetFilteredInboxNotificationsByUserID), ctx, arg) } +// GetForcedMCPServerConfigs mocks base method. +func (m *MockStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetForcedMCPServerConfigs", ctx) + ret0, _ := ret[0].([]database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetForcedMCPServerConfigs indicates an expected call of GetForcedMCPServerConfigs. +func (mr *MockStoreMockRecorder) GetForcedMCPServerConfigs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForcedMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetForcedMCPServerConfigs), ctx) +} + // GetGitSSHKey mocks base method. func (m *MockStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) { m.ctrl.T.Helper() @@ -2842,6 +2914,96 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx) } +// GetMCPServerConfigByID mocks base method. +func (m *MockStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerConfigByID", ctx, id) + ret0, _ := ret[0].(database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerConfigByID indicates an expected call of GetMCPServerConfigByID. +func (mr *MockStoreMockRecorder) GetMCPServerConfigByID(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigByID), ctx, id) +} + +// GetMCPServerConfigBySlug mocks base method. +func (m *MockStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerConfigBySlug", ctx, slug) + ret0, _ := ret[0].(database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerConfigBySlug indicates an expected call of GetMCPServerConfigBySlug. +func (mr *MockStoreMockRecorder) GetMCPServerConfigBySlug(ctx, slug any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigBySlug", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigBySlug), ctx, slug) +} + +// GetMCPServerConfigs mocks base method. +func (m *MockStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerConfigs", ctx) + ret0, _ := ret[0].([]database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerConfigs indicates an expected call of GetMCPServerConfigs. +func (mr *MockStoreMockRecorder) GetMCPServerConfigs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigs), ctx) +} + +// GetMCPServerConfigsByIDs mocks base method. +func (m *MockStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerConfigsByIDs", ctx, ids) + ret0, _ := ret[0].([]database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerConfigsByIDs indicates an expected call of GetMCPServerConfigsByIDs. +func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids) +} + +// GetMCPServerUserToken mocks base method. +func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerUserToken", ctx, arg) + ret0, _ := ret[0].(database.MCPServerUserToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerUserToken indicates an expected call of GetMCPServerUserToken. +func (mr *MockStoreMockRecorder) GetMCPServerUserToken(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserToken), ctx, arg) +} + +// GetMCPServerUserTokensByUserID mocks base method. +func (m *MockStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerUserTokensByUserID", ctx, userID) + ret0, _ := ret[0].([]database.MCPServerUserToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerUserTokensByUserID indicates an expected call of GetMCPServerUserTokensByUserID. +func (mr *MockStoreMockRecorder) GetMCPServerUserTokensByUserID(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserTokensByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserTokensByUserID), ctx, userID) +} + // GetNotificationMessagesByStatus mocks base method. func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { m.ctrl.T.Helper() @@ -5957,6 +6119,21 @@ func (mr *MockStoreMockRecorder) InsertLicense(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertLicense", reflect.TypeOf((*MockStore)(nil).InsertLicense), ctx, arg) } +// InsertMCPServerConfig mocks base method. +func (m *MockStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertMCPServerConfig", ctx, arg) + ret0, _ := ret[0].(database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertMCPServerConfig indicates an expected call of InsertMCPServerConfig. +func (mr *MockStoreMockRecorder) InsertMCPServerConfig(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMCPServerConfig", reflect.TypeOf((*MockStore)(nil).InsertMCPServerConfig), ctx, arg) +} + // InsertMemoryResourceMonitor mocks base method. func (m *MockStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) { m.ctrl.T.Helper() @@ -7256,6 +7433,21 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg) } +// UpdateChatMCPServerIDs mocks base method. +func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatMCPServerIDs", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatMCPServerIDs indicates an expected call of UpdateChatMCPServerIDs. +func (mr *MockStoreMockRecorder) UpdateChatMCPServerIDs(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMCPServerIDs", reflect.TypeOf((*MockStore)(nil).UpdateChatMCPServerIDs), ctx, arg) +} + // UpdateChatMessageByID mocks base method. func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) { m.ctrl.T.Helper() @@ -7449,6 +7641,21 @@ func (mr *MockStoreMockRecorder) UpdateInboxNotificationReadStatus(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInboxNotificationReadStatus", reflect.TypeOf((*MockStore)(nil).UpdateInboxNotificationReadStatus), ctx, arg) } +// UpdateMCPServerConfig mocks base method. +func (m *MockStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateMCPServerConfig", ctx, arg) + ret0, _ := ret[0].(database.MCPServerConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateMCPServerConfig indicates an expected call of UpdateMCPServerConfig. +func (mr *MockStoreMockRecorder) UpdateMCPServerConfig(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMCPServerConfig", reflect.TypeOf((*MockStore)(nil).UpdateMCPServerConfig), ctx, arg) +} + // UpdateMemberRoles mocks base method. func (m *MockStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { m.ctrl.T.Helper() @@ -8773,6 +8980,21 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value) } +// UpsertMCPServerUserToken mocks base method. +func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertMCPServerUserToken", ctx, arg) + ret0, _ := ret[0].(database.MCPServerUserToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertMCPServerUserToken indicates an expected call of UpsertMCPServerUserToken. +func (mr *MockStoreMockRecorder) UpsertMCPServerUserToken(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserToken), ctx, arg) +} + // UpsertNotificationReportGeneratorLog mocks base method. func (m *MockStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 836767f60f..f1b7f3763b 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1393,7 +1393,8 @@ CREATE TABLE chats ( last_model_config_id uuid NOT NULL, archived boolean DEFAULT false NOT NULL, last_error text, - mode chat_mode + mode chat_mode, + mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL ); CREATE TABLE connection_logs ( @@ -1670,6 +1671,53 @@ CREATE SEQUENCE licenses_id_seq ALTER SEQUENCE licenses_id_seq OWNED BY licenses.id; +CREATE TABLE mcp_server_configs ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + display_name text NOT NULL, + slug text NOT NULL, + description text DEFAULT ''::text NOT NULL, + icon_url text DEFAULT ''::text NOT NULL, + transport text DEFAULT 'streamable_http'::text NOT NULL, + url text NOT NULL, + auth_type text DEFAULT 'none'::text NOT NULL, + oauth2_client_id text DEFAULT ''::text NOT NULL, + oauth2_client_secret text DEFAULT ''::text NOT NULL, + oauth2_client_secret_key_id text, + oauth2_auth_url text DEFAULT ''::text NOT NULL, + oauth2_token_url text DEFAULT ''::text NOT NULL, + oauth2_scopes text DEFAULT ''::text NOT NULL, + api_key_header text DEFAULT 'Authorization'::text NOT NULL, + api_key_value text DEFAULT ''::text NOT NULL, + api_key_value_key_id text, + custom_headers text DEFAULT '{}'::text NOT NULL, + custom_headers_key_id text, + tool_allow_list text[] DEFAULT '{}'::text[] NOT NULL, + tool_deny_list text[] DEFAULT '{}'::text[] NOT NULL, + availability text DEFAULT 'default_off'::text NOT NULL, + enabled boolean DEFAULT false NOT NULL, + created_by uuid, + updated_by uuid, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))), + CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))), + CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text]))) +); + +CREATE TABLE mcp_server_user_tokens ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + mcp_server_config_id uuid NOT NULL, + user_id uuid NOT NULL, + access_token text NOT NULL, + access_token_key_id text, + refresh_token text DEFAULT ''::text NOT NULL, + refresh_token_key_id text, + token_type text DEFAULT 'Bearer'::text NOT NULL, + expiry timestamp with time zone, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + CREATE TABLE notification_messages ( id uuid NOT NULL, notification_template_id uuid NOT NULL, @@ -3343,6 +3391,18 @@ ALTER TABLE ONLY licenses ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug); + +ALTER TABLE ONLY mcp_server_user_tokens + ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id); + +ALTER TABLE ONLY mcp_server_user_tokens + ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id); + ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id); @@ -3691,6 +3751,12 @@ CREATE INDEX idx_inbox_notifications_user_id_read_at ON inbox_notifications USIN CREATE INDEX idx_inbox_notifications_user_id_template_id_targets ON inbox_notifications USING btree (user_id, template_id, targets); +CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs USING btree (enabled) WHERE (enabled = true); + +CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs USING btree (enabled, availability) WHERE ((enabled = true) AND (availability = 'force_on'::text)); + +CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens USING btree (user_id); + CREATE INDEX idx_notification_messages_status ON notification_messages USING btree (status); CREATE INDEX idx_organization_member_organization_id_uuid ON organization_members USING btree (organization_id); @@ -4015,6 +4081,33 @@ ALTER TABLE ONLY jfrog_xray_scans ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL; + +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY mcp_server_configs + ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL; + +ALTER TABLE ONLY mcp_server_user_tokens + ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY mcp_server_user_tokens + ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE; + +ALTER TABLE ONLY mcp_server_user_tokens + ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY mcp_server_user_tokens + ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index cbb47ce680..b6095e0547 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -40,6 +40,15 @@ const ( ForeignKeyInboxNotificationsUserID ForeignKeyConstraint = "inbox_notifications_user_id_fkey" // ALTER TABLE ONLY inbox_notifications ADD CONSTRAINT inbox_notifications_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyJfrogXrayScansAgentID ForeignKeyConstraint = "jfrog_xray_scans_agent_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE; ForeignKeyJfrogXrayScansWorkspaceID ForeignKeyConstraint = "jfrog_xray_scans_workspace_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; + ForeignKeyMcpServerConfigsAPIKeyValueKeyID ForeignKeyConstraint = "mcp_server_configs_api_key_value_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyMcpServerConfigsCreatedBy ForeignKeyConstraint = "mcp_server_configs_created_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL; + ForeignKeyMcpServerConfigsCustomHeadersKeyID ForeignKeyConstraint = "mcp_server_configs_custom_headers_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyMcpServerConfigsOauth2ClientSecretKeyID ForeignKeyConstraint = "mcp_server_configs_oauth2_client_secret_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyMcpServerConfigsUpdatedBy ForeignKeyConstraint = "mcp_server_configs_updated_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL; + ForeignKeyMcpServerUserTokensAccessTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_access_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyMcpServerUserTokensMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_tokens_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE; + ForeignKeyMcpServerUserTokensRefreshTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_refresh_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyMcpServerUserTokensUserID ForeignKeyConstraint = "mcp_server_user_tokens_user_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyNotificationMessagesNotificationTemplateID ForeignKeyConstraint = "notification_messages_notification_template_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE; ForeignKeyNotificationMessagesUserID ForeignKeyConstraint = "notification_messages_user_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyNotificationPreferencesNotificationTemplateID ForeignKeyConstraint = "notification_preferences_notification_template_id_fkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000447_mcp_server_configs.down.sql b/coderd/database/migrations/000447_mcp_server_configs.down.sql new file mode 100644 index 0000000000..ebf2ee1b58 --- /dev/null +++ b/coderd/database/migrations/000447_mcp_server_configs.down.sql @@ -0,0 +1,6 @@ +ALTER TABLE chats DROP COLUMN IF EXISTS mcp_server_ids; +DROP INDEX IF EXISTS idx_mcp_server_configs_enabled; +DROP INDEX IF EXISTS idx_mcp_server_configs_forced; +DROP INDEX IF EXISTS idx_mcp_server_user_tokens_user_id; +DROP TABLE IF EXISTS mcp_server_user_tokens; +DROP TABLE IF EXISTS mcp_server_configs; diff --git a/coderd/database/migrations/000447_mcp_server_configs.up.sql b/coderd/database/migrations/000447_mcp_server_configs.up.sql new file mode 100644 index 0000000000..f8a6c22b0f --- /dev/null +++ b/coderd/database/migrations/000447_mcp_server_configs.up.sql @@ -0,0 +1,75 @@ +CREATE TABLE mcp_server_configs ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + + -- Display + display_name TEXT NOT NULL, + slug TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + icon_url TEXT NOT NULL DEFAULT '', + + -- Connection + transport TEXT NOT NULL DEFAULT 'streamable_http' + CHECK (transport IN ('streamable_http', 'sse')), + url TEXT NOT NULL, + + -- Authentication + auth_type TEXT NOT NULL DEFAULT 'none' + CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers')), + + -- OAuth2 config (when auth_type = 'oauth2') + oauth2_client_id TEXT NOT NULL DEFAULT '', + oauth2_client_secret TEXT NOT NULL DEFAULT '', + oauth2_client_secret_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + oauth2_auth_url TEXT NOT NULL DEFAULT '', + oauth2_token_url TEXT NOT NULL DEFAULT '', + oauth2_scopes TEXT NOT NULL DEFAULT '', + + -- API key config (when auth_type = 'api_key') + api_key_header TEXT NOT NULL DEFAULT 'Authorization', + api_key_value TEXT NOT NULL DEFAULT '', + api_key_value_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + + -- Custom headers (when auth_type = 'custom_headers') + custom_headers TEXT NOT NULL DEFAULT '{}', + custom_headers_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + + -- Tool governance + tool_allow_list TEXT[] NOT NULL DEFAULT '{}', + tool_deny_list TEXT[] NOT NULL DEFAULT '{}', + + -- Availability policy + availability TEXT NOT NULL DEFAULT 'default_off' + CHECK (availability IN ('force_on', 'default_on', 'default_off')), + + -- Lifecycle + enabled BOOLEAN NOT NULL DEFAULT false, + created_by UUID REFERENCES users(id) ON DELETE SET NULL, + updated_by UUID REFERENCES users(id) ON DELETE SET NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE mcp_server_user_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + mcp_server_config_id UUID NOT NULL REFERENCES mcp_server_configs(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + + access_token TEXT NOT NULL, + access_token_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + refresh_token TEXT NOT NULL DEFAULT '', + refresh_token_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + token_type TEXT NOT NULL DEFAULT 'Bearer', + expiry TIMESTAMPTZ, + + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + + UNIQUE (mcp_server_config_id, user_id) +); + +-- Add MCP server selection to chats (per-chat, like model_config_id) +ALTER TABLE chats ADD COLUMN mcp_server_ids UUID[] NOT NULL DEFAULT '{}'; + +CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs(enabled) WHERE enabled = TRUE; +CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs(enabled, availability) WHERE enabled = TRUE AND availability = 'force_on'; +CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens(user_id); diff --git a/coderd/database/migrations/testdata/fixtures/000447_mcp_server_configs.up.sql b/coderd/database/migrations/testdata/fixtures/000447_mcp_server_configs.up.sql new file mode 100644 index 0000000000..c3aea6c5dc --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000447_mcp_server_configs.up.sql @@ -0,0 +1,48 @@ +INSERT INTO mcp_server_configs ( + id, + display_name, + slug, + url, + transport, + auth_type, + availability, + enabled, + created_by, + updated_by, + created_at, + updated_at +) VALUES ( + 'a1b2c3d4-e5f6-7890-abcd-ef1234567890', + 'Fixture MCP Server', + 'fixture-mcp-server', + 'https://mcp.example.com/sse', + 'sse', + 'none', + 'default_on', + TRUE, + '30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com + '30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com + '2024-01-01 00:00:00+00', + '2024-01-01 00:00:00+00' +); + +INSERT INTO mcp_server_user_tokens ( + id, + mcp_server_config_id, + user_id, + access_token, + token_type, + created_at, + updated_at +) +SELECT + 'b2c3d4e5-f6a7-8901-bcde-f12345678901', + 'a1b2c3d4-e5f6-7890-abcd-ef1234567890', + id, + 'fixture-access-token', + 'Bearer', + '2024-01-01 00:00:00+00', + '2024-01-01 00:00:00+00' +FROM users +ORDER BY created_at, id +LIMIT 1; diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 8bceef79eb..c39b06202b 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -787,6 +787,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index ec21b110db..d24c13486c 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4167,6 +4167,7 @@ type Chat struct { Archived bool `db:"archived" json:"archived"` LastError sql.NullString `db:"last_error" json:"last_error"` Mode NullChatMode `db:"mode" json:"mode"` + MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` } type ChatDiffStatus struct { @@ -4452,6 +4453,50 @@ type License struct { UUID uuid.UUID `db:"uuid" json:"uuid"` } +type MCPServerConfig struct { + ID uuid.UUID `db:"id" json:"id"` + DisplayName string `db:"display_name" json:"display_name"` + Slug string `db:"slug" json:"slug"` + Description string `db:"description" json:"description"` + IconURL string `db:"icon_url" json:"icon_url"` + Transport string `db:"transport" json:"transport"` + Url string `db:"url" json:"url"` + AuthType string `db:"auth_type" json:"auth_type"` + OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` + OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` + OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` + OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` + OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` + OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` + APIKeyHeader string `db:"api_key_header" json:"api_key_header"` + APIKeyValue string `db:"api_key_value" json:"api_key_value"` + APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` + CustomHeaders string `db:"custom_headers" json:"custom_headers"` + CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` + ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` + Availability string `db:"availability" json:"availability"` + Enabled bool `db:"enabled" json:"enabled"` + CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` + UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +type MCPServerUserToken struct { + ID uuid.UUID `db:"id" json:"id"` + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + AccessToken string `db:"access_token" json:"access_token"` + AccessTokenKeyID sql.NullString `db:"access_token_key_id" json:"access_token_key_id"` + RefreshToken string `db:"refresh_token" json:"refresh_token"` + RefreshTokenKeyID sql.NullString `db:"refresh_token_key_id" json:"refresh_token_key_id"` + TokenType string `db:"token_type" json:"token_type"` + Expiry sql.NullTime `db:"expiry" json:"expiry"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + type NotificationMessage struct { ID uuid.UUID `db:"id" json:"id"` NotificationTemplateID uuid.UUID `db:"notification_template_id" json:"notification_template_id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 90786a7d7d..477f66d6c4 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -74,6 +74,7 @@ type sqlcQuerier interface { CleanTailnetCoordinators(ctx context.Context) error CleanTailnetLostPeers(ctx context.Context) error CleanTailnetTunnels(ctx context.Context) error + CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) @@ -110,6 +111,8 @@ type sqlcQuerier interface { DeleteGroupByID(ctx context.Context, id uuid.UUID) error DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) + DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error + DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error @@ -269,6 +272,7 @@ type sqlcQuerier interface { GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.Context, provisionerJobIds []uuid.UUID) ([]GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error) GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) + GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg GetFailedWorkspaceBuildsByTemplateIDParams) ([]GetFailedWorkspaceBuildsByTemplateIDRow, error) @@ -284,6 +288,7 @@ type sqlcQuerier interface { // param created_at_opt: The created_at timestamp to filter by. This parameter is usd for pagination - it fetches notifications created before the specified timestamp if it is not the zero value // param limit_opt: The limit of notifications to fetch. If the limit is not specified, it defaults to 25 GetFilteredInboxNotificationsByUserID(ctx context.Context, arg GetFilteredInboxNotificationsByUserIDParams) ([]InboxNotification, error) + GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error) @@ -312,6 +317,12 @@ type sqlcQuerier interface { GetLicenseByID(ctx context.Context, id int32) (License, error) GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) + GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error) + GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error) + GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) + GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error) + GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error) + GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) // Fetch the notification report generator log indicating recent activity. GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error) @@ -675,6 +686,7 @@ type sqlcQuerier interface { InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error InsertInboxNotification(ctx context.Context, arg InsertInboxNotificationParams) (InboxNotification, error) InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error) + InsertMCPServerConfig(ctx context.Context, arg InsertMCPServerConfigParams) (MCPServerConfig, error) InsertMemoryResourceMonitor(ctx context.Context, arg InsertMemoryResourceMonitorParams) (WorkspaceAgentMemoryResourceMonitor, error) // Inserts any group by name that does not exist. All new groups are given // a random uuid, are inserted into the same organization. They have the default @@ -796,6 +808,7 @@ type sqlcQuerier interface { // Bumps the heartbeat timestamp for a running chat so that other // replicas know the worker is still alive. UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) + UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error) UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) @@ -813,6 +826,7 @@ type sqlcQuerier interface { UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) UpdateInboxNotificationReadStatus(ctx context.Context, arg UpdateInboxNotificationReadStatusParams) error + UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPServerConfigParams) (MCPServerConfig, error) UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error) UpdateMemoryResourceMonitor(ctx context.Context, arg UpdateMemoryResourceMonitorParams) error UpdateNotificationTemplateMethodByID(ctx context.Context, arg UpdateNotificationTemplateMethodByIDParams) (NotificationTemplate, error) @@ -917,6 +931,7 @@ type sqlcQuerier interface { UpsertHealthSettings(ctx context.Context, value string) error UpsertLastUpdateCheck(ctx context.Context, value string) error UpsertLogoURL(ctx context.Context, value string) error + UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error) // Insert or update notification report generator logs with recent activity. UpsertNotificationReportGeneratorLog(ctx context.Context, arg UpsertNotificationReportGeneratorLogParams) error UpsertNotificationsSettings(ctx context.Context, value string) error diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index ec5bcfe988..d3560afd72 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -1653,12 +1653,12 @@ func TestDefaultProxy(t *testing.T) { require.NoError(t, err, "get def proxy") require.Equal(t, defProxy.DisplayName, "Default") - require.Equal(t, defProxy.IconUrl, "/emojis/1f3e1.png") + require.Equal(t, defProxy.IconURL, "/emojis/1f3e1.png") // Set the proxy values args := database.UpsertDefaultProxyParams{ DisplayName: "displayname", - IconUrl: "/icon.png", + IconURL: "/icon.png", } err = db.UpsertDefaultProxy(ctx, args) require.NoError(t, err, "insert def proxy") @@ -1666,12 +1666,12 @@ func TestDefaultProxy(t *testing.T) { defProxy, err = db.GetDefaultProxyConfig(ctx) require.NoError(t, err, "get def proxy") require.Equal(t, defProxy.DisplayName, args.DisplayName) - require.Equal(t, defProxy.IconUrl, args.IconUrl) + require.Equal(t, defProxy.IconURL, args.IconURL) // Upsert values args = database.UpsertDefaultProxyParams{ DisplayName: "newdisplayname", - IconUrl: "/newicon.png", + IconURL: "/newicon.png", } err = db.UpsertDefaultProxy(ctx, args) require.NoError(t, err, "upsert def proxy") @@ -1679,7 +1679,7 @@ func TestDefaultProxy(t *testing.T) { defProxy, err = db.GetDefaultProxyConfig(ctx) require.NoError(t, err, "get def proxy") require.Equal(t, defProxy.DisplayName, args.DisplayName) - require.Equal(t, defProxy.IconUrl, args.IconUrl) + require.Equal(t, defProxy.IconURL, args.IconURL) // Ensure other site configs are the same found, err := db.GetDeploymentID(ctx) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 50ce3430be..9421b60260 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3523,7 +3523,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids ` type AcquireChatsParams struct { @@ -3560,6 +3560,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ); err != nil { return nil, err } @@ -3787,7 +3788,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids FROM chats WHERE @@ -3814,12 +3815,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -3842,6 +3844,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ) return i, err } @@ -4681,7 +4684,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids FROM chats WHERE @@ -4764,6 +4767,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ); err != nil { return nil, err } @@ -4828,7 +4832,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids FROM chats WHERE @@ -4864,6 +4868,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ); err != nil { return nil, err } @@ -4926,7 +4931,8 @@ INSERT INTO chats ( root_chat_id, last_model_config_id, title, - mode + mode, + mcp_server_ids ) VALUES ( $1::uuid, $2::uuid, @@ -4934,10 +4940,11 @@ INSERT INTO chats ( $4::uuid, $5::uuid, $6::text, - $7::chat_mode + $7::chat_mode, + COALESCE($8::uuid[], '{}'::uuid[]) ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids ` type InsertChatParams struct { @@ -4948,6 +4955,7 @@ type InsertChatParams struct { LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` Title string `db:"title" json:"title"` Mode NullChatMode `db:"mode" json:"mode"` + MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` } func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) { @@ -4959,6 +4967,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat arg.LastModelConfigID, arg.Title, arg.Mode, + pq.Array(arg.MCPServerIDs), ) var i Chat err := row.Scan( @@ -4978,6 +4987,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ) return i, err } @@ -5368,7 +5378,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids ` type UpdateChatByIDParams struct { @@ -5396,6 +5406,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ) return i, err } @@ -5426,6 +5437,48 @@ func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHear return result.RowsAffected() } +const updateChatMCPServerIDs = `-- name: UpdateChatMCPServerIDs :one +UPDATE + chats +SET + mcp_server_ids = $1::uuid[], + updated_at = NOW() +WHERE + id = $2::uuid +RETURNING + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids +` + +type UpdateChatMCPServerIDsParams struct { + MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatMCPServerIDs, pq.Array(arg.MCPServerIDs), arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + ) + return i, err +} + const updateChatMessageByID = `-- name: UpdateChatMessageByID :one UPDATE chat_messages @@ -5485,7 +5538,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids ` type UpdateChatStatusParams struct { @@ -5524,6 +5577,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ) return i, err } @@ -5537,7 +5591,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids ` type UpdateChatWorkspaceParams struct { @@ -5565,6 +5619,7 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork &i.Archived, &i.LastError, &i.Mode, + pq.Array(&i.MCPServerIDs), ) return i, err } @@ -9382,6 +9437,800 @@ func (q *sqlQuerier) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock i return pg_try_advisory_xact_lock, err } +const cleanupDeletedMCPServerIDsFromChats = `-- name: CleanupDeletedMCPServerIDsFromChats :exec +UPDATE chats +SET mcp_server_ids = ( + SELECT COALESCE(array_agg(sid), '{}') + FROM unnest(chats.mcp_server_ids) AS sid + WHERE sid IN (SELECT id FROM mcp_server_configs) +) +WHERE mcp_server_ids != '{}' + AND NOT (mcp_server_ids <@ COALESCE((SELECT array_agg(id) FROM mcp_server_configs), '{}')) +` + +func (q *sqlQuerier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, cleanupDeletedMCPServerIDsFromChats) + return err +} + +const deleteMCPServerConfigByID = `-- name: DeleteMCPServerConfigByID :exec +DELETE FROM + mcp_server_configs +WHERE + id = $1::uuid +` + +func (q *sqlQuerier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteMCPServerConfigByID, id) + return err +} + +const deleteMCPServerUserToken = `-- name: DeleteMCPServerUserToken :exec +DELETE FROM + mcp_server_user_tokens +WHERE + mcp_server_config_id = $1::uuid + AND user_id = $2::uuid +` + +type DeleteMCPServerUserTokenParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error { + _, err := q.db.ExecContext(ctx, deleteMCPServerUserToken, arg.MCPServerConfigID, arg.UserID) + return err +} + +const getEnabledMCPServerConfigs = `-- name: GetEnabledMCPServerConfigs :many +SELECT + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +FROM + mcp_server_configs +WHERE + enabled = TRUE +ORDER BY + display_name ASC +` + +func (q *sqlQuerier) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) { + rows, err := q.db.QueryContext(ctx, getEnabledMCPServerConfigs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []MCPServerConfig + for rows.Next() { + var i MCPServerConfig + if err := rows.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getForcedMCPServerConfigs = `-- name: GetForcedMCPServerConfigs :many +SELECT + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +FROM + mcp_server_configs +WHERE + enabled = TRUE + AND availability = 'force_on' +ORDER BY + display_name ASC +` + +func (q *sqlQuerier) GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) { + rows, err := q.db.QueryContext(ctx, getForcedMCPServerConfigs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []MCPServerConfig + for rows.Next() { + var i MCPServerConfig + if err := rows.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getMCPServerConfigByID = `-- name: GetMCPServerConfigByID :one +SELECT + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +FROM + mcp_server_configs +WHERE + id = $1::uuid +` + +func (q *sqlQuerier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error) { + row := q.db.QueryRowContext(ctx, getMCPServerConfigByID, id) + var i MCPServerConfig + err := row.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getMCPServerConfigBySlug = `-- name: GetMCPServerConfigBySlug :one +SELECT + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +FROM + mcp_server_configs +WHERE + slug = $1::text +` + +func (q *sqlQuerier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error) { + row := q.db.QueryRowContext(ctx, getMCPServerConfigBySlug, slug) + var i MCPServerConfig + err := row.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getMCPServerConfigs = `-- name: GetMCPServerConfigs :many +SELECT + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +FROM + mcp_server_configs +ORDER BY + display_name ASC +` + +func (q *sqlQuerier) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) { + rows, err := q.db.QueryContext(ctx, getMCPServerConfigs) + if err != nil { + return nil, err + } + defer rows.Close() + var items []MCPServerConfig + for rows.Next() { + var i MCPServerConfig + if err := rows.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getMCPServerConfigsByIDs = `-- name: GetMCPServerConfigsByIDs :many +SELECT + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +FROM + mcp_server_configs +WHERE + id = ANY($1::uuid[]) +ORDER BY + display_name ASC +` + +func (q *sqlQuerier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error) { + rows, err := q.db.QueryContext(ctx, getMCPServerConfigsByIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []MCPServerConfig + for rows.Next() { + var i MCPServerConfig + if err := rows.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getMCPServerUserToken = `-- name: GetMCPServerUserToken :one +SELECT + id, mcp_server_config_id, user_id, access_token, access_token_key_id, refresh_token, refresh_token_key_id, token_type, expiry, created_at, updated_at +FROM + mcp_server_user_tokens +WHERE + mcp_server_config_id = $1::uuid + AND user_id = $2::uuid +` + +type GetMCPServerUserTokenParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error) { + row := q.db.QueryRowContext(ctx, getMCPServerUserToken, arg.MCPServerConfigID, arg.UserID) + var i MCPServerUserToken + err := row.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.AccessToken, + &i.AccessTokenKeyID, + &i.RefreshToken, + &i.RefreshTokenKeyID, + &i.TokenType, + &i.Expiry, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getMCPServerUserTokensByUserID = `-- name: GetMCPServerUserTokensByUserID :many +SELECT + id, mcp_server_config_id, user_id, access_token, access_token_key_id, refresh_token, refresh_token_key_id, token_type, expiry, created_at, updated_at +FROM + mcp_server_user_tokens +WHERE + user_id = $1::uuid +` + +func (q *sqlQuerier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error) { + rows, err := q.db.QueryContext(ctx, getMCPServerUserTokensByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []MCPServerUserToken + for rows.Next() { + var i MCPServerUserToken + if err := rows.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.AccessToken, + &i.AccessTokenKeyID, + &i.RefreshToken, + &i.RefreshTokenKeyID, + &i.TokenType, + &i.Expiry, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertMCPServerConfig = `-- name: InsertMCPServerConfig :one +INSERT INTO mcp_server_configs ( + display_name, + slug, + description, + icon_url, + transport, + url, + auth_type, + oauth2_client_id, + oauth2_client_secret, + oauth2_client_secret_key_id, + oauth2_auth_url, + oauth2_token_url, + oauth2_scopes, + api_key_header, + api_key_value, + api_key_value_key_id, + custom_headers, + custom_headers_key_id, + tool_allow_list, + tool_deny_list, + availability, + enabled, + created_by, + updated_by +) VALUES ( + $1::text, + $2::text, + $3::text, + $4::text, + $5::text, + $6::text, + $7::text, + $8::text, + $9::text, + $10::text, + $11::text, + $12::text, + $13::text, + $14::text, + $15::text, + $16::text, + $17::text, + $18::text, + $19::text[], + $20::text[], + $21::text, + $22::boolean, + $23::uuid, + $24::uuid +) +RETURNING + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +` + +type InsertMCPServerConfigParams struct { + DisplayName string `db:"display_name" json:"display_name"` + Slug string `db:"slug" json:"slug"` + Description string `db:"description" json:"description"` + IconURL string `db:"icon_url" json:"icon_url"` + Transport string `db:"transport" json:"transport"` + Url string `db:"url" json:"url"` + AuthType string `db:"auth_type" json:"auth_type"` + OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` + OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` + OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` + OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` + OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` + OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` + APIKeyHeader string `db:"api_key_header" json:"api_key_header"` + APIKeyValue string `db:"api_key_value" json:"api_key_value"` + APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` + CustomHeaders string `db:"custom_headers" json:"custom_headers"` + CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` + ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` + Availability string `db:"availability" json:"availability"` + Enabled bool `db:"enabled" json:"enabled"` + CreatedBy uuid.UUID `db:"created_by" json:"created_by"` + UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` +} + +func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPServerConfigParams) (MCPServerConfig, error) { + row := q.db.QueryRowContext(ctx, insertMCPServerConfig, + arg.DisplayName, + arg.Slug, + arg.Description, + arg.IconURL, + arg.Transport, + arg.Url, + arg.AuthType, + arg.OAuth2ClientID, + arg.OAuth2ClientSecret, + arg.OAuth2ClientSecretKeyID, + arg.OAuth2AuthURL, + arg.OAuth2TokenURL, + arg.OAuth2Scopes, + arg.APIKeyHeader, + arg.APIKeyValue, + arg.APIKeyValueKeyID, + arg.CustomHeaders, + arg.CustomHeadersKeyID, + pq.Array(arg.ToolAllowList), + pq.Array(arg.ToolDenyList), + arg.Availability, + arg.Enabled, + arg.CreatedBy, + arg.UpdatedBy, + ) + var i MCPServerConfig + err := row.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateMCPServerConfig = `-- name: UpdateMCPServerConfig :one +UPDATE + mcp_server_configs +SET + display_name = $1::text, + slug = $2::text, + description = $3::text, + icon_url = $4::text, + transport = $5::text, + url = $6::text, + auth_type = $7::text, + oauth2_client_id = $8::text, + oauth2_client_secret = $9::text, + oauth2_client_secret_key_id = $10::text, + oauth2_auth_url = $11::text, + oauth2_token_url = $12::text, + oauth2_scopes = $13::text, + api_key_header = $14::text, + api_key_value = $15::text, + api_key_value_key_id = $16::text, + custom_headers = $17::text, + custom_headers_key_id = $18::text, + tool_allow_list = $19::text[], + tool_deny_list = $20::text[], + availability = $21::text, + enabled = $22::boolean, + updated_by = $23::uuid, + updated_at = NOW() +WHERE + id = $24::uuid +RETURNING + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at +` + +type UpdateMCPServerConfigParams struct { + DisplayName string `db:"display_name" json:"display_name"` + Slug string `db:"slug" json:"slug"` + Description string `db:"description" json:"description"` + IconURL string `db:"icon_url" json:"icon_url"` + Transport string `db:"transport" json:"transport"` + Url string `db:"url" json:"url"` + AuthType string `db:"auth_type" json:"auth_type"` + OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` + OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` + OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` + OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` + OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` + OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` + APIKeyHeader string `db:"api_key_header" json:"api_key_header"` + APIKeyValue string `db:"api_key_value" json:"api_key_value"` + APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` + CustomHeaders string `db:"custom_headers" json:"custom_headers"` + CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` + ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` + Availability string `db:"availability" json:"availability"` + Enabled bool `db:"enabled" json:"enabled"` + UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPServerConfigParams) (MCPServerConfig, error) { + row := q.db.QueryRowContext(ctx, updateMCPServerConfig, + arg.DisplayName, + arg.Slug, + arg.Description, + arg.IconURL, + arg.Transport, + arg.Url, + arg.AuthType, + arg.OAuth2ClientID, + arg.OAuth2ClientSecret, + arg.OAuth2ClientSecretKeyID, + arg.OAuth2AuthURL, + arg.OAuth2TokenURL, + arg.OAuth2Scopes, + arg.APIKeyHeader, + arg.APIKeyValue, + arg.APIKeyValueKeyID, + arg.CustomHeaders, + arg.CustomHeadersKeyID, + pq.Array(arg.ToolAllowList), + pq.Array(arg.ToolDenyList), + arg.Availability, + arg.Enabled, + arg.UpdatedBy, + arg.ID, + ) + var i MCPServerConfig + err := row.Scan( + &i.ID, + &i.DisplayName, + &i.Slug, + &i.Description, + &i.IconURL, + &i.Transport, + &i.Url, + &i.AuthType, + &i.OAuth2ClientID, + &i.OAuth2ClientSecret, + &i.OAuth2ClientSecretKeyID, + &i.OAuth2AuthURL, + &i.OAuth2TokenURL, + &i.OAuth2Scopes, + &i.APIKeyHeader, + &i.APIKeyValue, + &i.APIKeyValueKeyID, + &i.CustomHeaders, + &i.CustomHeadersKeyID, + pq.Array(&i.ToolAllowList), + pq.Array(&i.ToolDenyList), + &i.Availability, + &i.Enabled, + &i.CreatedBy, + &i.UpdatedBy, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertMCPServerUserToken = `-- name: UpsertMCPServerUserToken :one +INSERT INTO mcp_server_user_tokens ( + mcp_server_config_id, + user_id, + access_token, + access_token_key_id, + refresh_token, + refresh_token_key_id, + token_type, + expiry +) VALUES ( + $1::uuid, + $2::uuid, + $3::text, + $4::text, + $5::text, + $6::text, + $7::text, + $8::timestamptz +) +ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET + access_token = $3::text, + access_token_key_id = $4::text, + refresh_token = $5::text, + refresh_token_key_id = $6::text, + token_type = $7::text, + expiry = $8::timestamptz, + updated_at = NOW() +RETURNING + id, mcp_server_config_id, user_id, access_token, access_token_key_id, refresh_token, refresh_token_key_id, token_type, expiry, created_at, updated_at +` + +type UpsertMCPServerUserTokenParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + AccessToken string `db:"access_token" json:"access_token"` + AccessTokenKeyID sql.NullString `db:"access_token_key_id" json:"access_token_key_id"` + RefreshToken string `db:"refresh_token" json:"refresh_token"` + RefreshTokenKeyID sql.NullString `db:"refresh_token_key_id" json:"refresh_token_key_id"` + TokenType string `db:"token_type" json:"token_type"` + Expiry sql.NullTime `db:"expiry" json:"expiry"` +} + +func (q *sqlQuerier) UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error) { + row := q.db.QueryRowContext(ctx, upsertMCPServerUserToken, + arg.MCPServerConfigID, + arg.UserID, + arg.AccessToken, + arg.AccessTokenKeyID, + arg.RefreshToken, + arg.RefreshTokenKeyID, + arg.TokenType, + arg.Expiry, + ) + var i MCPServerUserToken + err := row.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.AccessToken, + &i.AccessTokenKeyID, + &i.RefreshToken, + &i.RefreshTokenKeyID, + &i.TokenType, + &i.Expiry, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const acquireNotificationMessages = `-- name: AcquireNotificationMessages :many WITH acquired AS ( UPDATE @@ -15997,13 +16846,13 @@ SELECT type GetDefaultProxyConfigRow struct { DisplayName string `db:"display_name" json:"display_name"` - IconUrl string `db:"icon_url" json:"icon_url"` + IconURL string `db:"icon_url" json:"icon_url"` } func (q *sqlQuerier) GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error) { row := q.db.QueryRowContext(ctx, getDefaultProxyConfig) var i GetDefaultProxyConfigRow - err := row.Scan(&i.DisplayName, &i.IconUrl) + err := row.Scan(&i.DisplayName, &i.IconURL) return i, err } @@ -16204,14 +17053,14 @@ DO UPDATE SET value = EXCLUDED.value WHERE site_configs.key = EXCLUDED.key type UpsertDefaultProxyParams struct { DisplayName string `db:"display_name" json:"display_name"` - IconUrl string `db:"icon_url" json:"icon_url"` + IconURL string `db:"icon_url" json:"icon_url"` } // The default proxy is implied and not actually stored in the database. // So we need to store it's configuration here for display purposes. // The functional values are immutable and controlled implicitly. func (q *sqlQuerier) UpsertDefaultProxy(ctx context.Context, arg UpsertDefaultProxyParams) error { - _, err := q.db.ExecContext(ctx, upsertDefaultProxy, arg.DisplayName, arg.IconUrl) + _, err := q.db.ExecContext(ctx, upsertDefaultProxy, arg.DisplayName, arg.IconURL) return err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index d6d1e37653..a6b405a576 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -180,7 +180,8 @@ INSERT INTO chats ( root_chat_id, last_model_config_id, title, - mode + mode, + mcp_server_ids ) VALUES ( @owner_id::uuid, sqlc.narg('workspace_id')::uuid, @@ -188,7 +189,8 @@ INSERT INTO chats ( sqlc.narg('root_chat_id')::uuid, @last_model_config_id::uuid, @title::text, - sqlc.narg('mode')::chat_mode + sqlc.narg('mode')::chat_mode, + COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]) ) RETURNING *; @@ -295,6 +297,17 @@ WHERE RETURNING *; +-- name: UpdateChatMCPServerIDs :one +UPDATE + chats +SET + mcp_server_ids = @mcp_server_ids::uuid[], + updated_at = NOW() +WHERE + id = @id::uuid +RETURNING + *; + -- name: AcquireChats :many -- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED -- to prevent multiple replicas from acquiring the same chat. diff --git a/coderd/database/queries/mcpserverconfigs.sql b/coderd/database/queries/mcpserverconfigs.sql new file mode 100644 index 0000000000..28cf83b22a --- /dev/null +++ b/coderd/database/queries/mcpserverconfigs.sql @@ -0,0 +1,213 @@ +-- name: GetMCPServerConfigByID :one +SELECT + * +FROM + mcp_server_configs +WHERE + id = @id::uuid; + +-- name: GetMCPServerConfigBySlug :one +SELECT + * +FROM + mcp_server_configs +WHERE + slug = @slug::text; + +-- name: GetMCPServerConfigs :many +SELECT + * +FROM + mcp_server_configs +ORDER BY + display_name ASC; + +-- name: GetEnabledMCPServerConfigs :many +SELECT + * +FROM + mcp_server_configs +WHERE + enabled = TRUE +ORDER BY + display_name ASC; + +-- name: GetMCPServerConfigsByIDs :many +SELECT + * +FROM + mcp_server_configs +WHERE + id = ANY(@ids::uuid[]) +ORDER BY + display_name ASC; + +-- name: GetForcedMCPServerConfigs :many +SELECT + * +FROM + mcp_server_configs +WHERE + enabled = TRUE + AND availability = 'force_on' +ORDER BY + display_name ASC; + +-- name: InsertMCPServerConfig :one +INSERT INTO mcp_server_configs ( + display_name, + slug, + description, + icon_url, + transport, + url, + auth_type, + oauth2_client_id, + oauth2_client_secret, + oauth2_client_secret_key_id, + oauth2_auth_url, + oauth2_token_url, + oauth2_scopes, + api_key_header, + api_key_value, + api_key_value_key_id, + custom_headers, + custom_headers_key_id, + tool_allow_list, + tool_deny_list, + availability, + enabled, + created_by, + updated_by +) VALUES ( + @display_name::text, + @slug::text, + @description::text, + @icon_url::text, + @transport::text, + @url::text, + @auth_type::text, + @oauth2_client_id::text, + @oauth2_client_secret::text, + sqlc.narg('oauth2_client_secret_key_id')::text, + @oauth2_auth_url::text, + @oauth2_token_url::text, + @oauth2_scopes::text, + @api_key_header::text, + @api_key_value::text, + sqlc.narg('api_key_value_key_id')::text, + @custom_headers::text, + sqlc.narg('custom_headers_key_id')::text, + @tool_allow_list::text[], + @tool_deny_list::text[], + @availability::text, + @enabled::boolean, + @created_by::uuid, + @updated_by::uuid +) +RETURNING + *; + +-- name: UpdateMCPServerConfig :one +UPDATE + mcp_server_configs +SET + display_name = @display_name::text, + slug = @slug::text, + description = @description::text, + icon_url = @icon_url::text, + transport = @transport::text, + url = @url::text, + auth_type = @auth_type::text, + oauth2_client_id = @oauth2_client_id::text, + oauth2_client_secret = @oauth2_client_secret::text, + oauth2_client_secret_key_id = sqlc.narg('oauth2_client_secret_key_id')::text, + oauth2_auth_url = @oauth2_auth_url::text, + oauth2_token_url = @oauth2_token_url::text, + oauth2_scopes = @oauth2_scopes::text, + api_key_header = @api_key_header::text, + api_key_value = @api_key_value::text, + api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text, + custom_headers = @custom_headers::text, + custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text, + tool_allow_list = @tool_allow_list::text[], + tool_deny_list = @tool_deny_list::text[], + availability = @availability::text, + enabled = @enabled::boolean, + updated_by = @updated_by::uuid, + updated_at = NOW() +WHERE + id = @id::uuid +RETURNING + *; + +-- name: DeleteMCPServerConfigByID :exec +DELETE FROM + mcp_server_configs +WHERE + id = @id::uuid; + +-- name: GetMCPServerUserToken :one +SELECT + * +FROM + mcp_server_user_tokens +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid + AND user_id = @user_id::uuid; + +-- name: GetMCPServerUserTokensByUserID :many +SELECT + * +FROM + mcp_server_user_tokens +WHERE + user_id = @user_id::uuid; + +-- name: UpsertMCPServerUserToken :one +INSERT INTO mcp_server_user_tokens ( + mcp_server_config_id, + user_id, + access_token, + access_token_key_id, + refresh_token, + refresh_token_key_id, + token_type, + expiry +) VALUES ( + @mcp_server_config_id::uuid, + @user_id::uuid, + @access_token::text, + sqlc.narg('access_token_key_id')::text, + @refresh_token::text, + sqlc.narg('refresh_token_key_id')::text, + @token_type::text, + sqlc.narg('expiry')::timestamptz +) +ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET + access_token = @access_token::text, + access_token_key_id = sqlc.narg('access_token_key_id')::text, + refresh_token = @refresh_token::text, + refresh_token_key_id = sqlc.narg('refresh_token_key_id')::text, + token_type = @token_type::text, + expiry = sqlc.narg('expiry')::timestamptz, + updated_at = NOW() +RETURNING + *; + +-- name: DeleteMCPServerUserToken :exec +DELETE FROM + mcp_server_user_tokens +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid + AND user_id = @user_id::uuid; + +-- name: CleanupDeletedMCPServerIDsFromChats :exec +UPDATE chats +SET mcp_server_ids = ( + SELECT COALESCE(array_agg(sid), '{}') + FROM unnest(chats.mcp_server_ids) AS sid + WHERE sid IN (SELECT id FROM mcp_server_configs) +) +WHERE mcp_server_ids != '{}' + AND NOT (mcp_server_ids <@ COALESCE((SELECT array_agg(id) FROM mcp_server_configs), '{}')); diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 72c968fcd8..a6d5396b44 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -236,6 +236,28 @@ sql: aibridge_token_usage: AIBridgeTokenUsage aibridge_user_prompt: AIBridgeUserPrompt aibridge_model_thought: AIBridgeModelThought + mcp_server_config: MCPServerConfig + mcp_server_configs: MCPServerConfigs + mcp_server_user_token: MCPServerUserToken + mcp_server_user_tokens: MCPServerUserTokens + mcp_server_tool_snapshot: MCPServerToolSnapshot + mcp_server_tool_snapshots: MCPServerToolSnapshots + mcp_server_config_id: MCPServerConfigID + mcp_server_ids: MCPServerIDs + icon_url: IconURL + oauth2_client_id: OAuth2ClientID + oauth2_client_secret: OAuth2ClientSecret + oauth2_client_secret_key_id: OAuth2ClientSecretKeyID + oauth2_auth_url: OAuth2AuthURL + oauth2_token_url: OAuth2TokenURL + oauth2_scopes: OAuth2Scopes + api_key_header: APIKeyHeader + api_key_value: APIKeyValue + api_key_value_key_id: APIKeyValueKeyID + custom_headers_key_id: CustomHeadersKeyID + tools_json: ToolsJSON + access_token_key_id: AccessTokenKeyID + refresh_token_key_id: RefreshTokenKeyID rules: - name: do-not-use-public-schema-in-queries message: "do not use public schema in queries" diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 35f40d7c5f..8a123be0cb 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -42,6 +42,10 @@ const ( UniqueJfrogXrayScansPkey UniqueConstraint = "jfrog_xray_scans_pkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id); UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt); UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); + UniqueMcpServerConfigsPkey UniqueConstraint = "mcp_server_configs_pkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id); + UniqueMcpServerConfigsSlugKey UniqueConstraint = "mcp_server_configs_slug_key" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug); + UniqueMcpServerUserTokensMcpServerConfigIDUserIDKey UniqueConstraint = "mcp_server_user_tokens_mcp_server_config_id_user_id_key" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id); + UniqueMcpServerUserTokensPkey UniqueConstraint = "mcp_server_user_tokens_pkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id); UniqueNotificationMessagesPkey UniqueConstraint = "notification_messages_pkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id); UniqueNotificationPreferencesPkey UniqueConstraint = "notification_preferences_pkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_pkey PRIMARY KEY (user_id, notification_template_id); UniqueNotificationReportGeneratorLogsPkey UniqueConstraint = "notification_report_generator_logs_pkey" // ALTER TABLE ONLY notification_report_generator_logs ADD CONSTRAINT notification_report_generator_logs_pkey PRIMARY KEY (notification_template_id); diff --git a/coderd/mcp.go b/coderd/mcp.go new file mode 100644 index 0000000000..39b41919ed --- /dev/null +++ b/coderd/mcp.go @@ -0,0 +1,921 @@ +package coderd + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "golang.org/x/oauth2" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" + "github.com/coder/coder/v2/codersdk" +) + +// @Summary List MCP server configs +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) listMCPServerConfigs(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + // Admin users can see all MCP server configs (including disabled + // ones) for management purposes. Non-admin users see only enabled + // configs, which is sufficient for using the chat feature. + isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) + + var configs []database.MCPServerConfig + var err error + if isAdmin { + configs, err = api.Database.GetMCPServerConfigs(ctx) + } else { + //nolint:gocritic // All authenticated users need to read enabled MCP server configs to use the chat feature. + configs, err = api.Database.GetEnabledMCPServerConfigs(dbauthz.AsSystemRestricted(ctx)) + } + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to list MCP server configs.", + Detail: err.Error(), + }) + return + } + + // Look up the calling user's OAuth2 tokens so we can populate + // auth_connected per server. + //nolint:gocritic // Need to check user tokens across all servers. + userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get user tokens.", + Detail: err.Error(), + }) + return + } + tokenMap := make(map[uuid.UUID]bool, len(userTokens)) + for _, t := range userTokens { + tokenMap[t.MCPServerConfigID] = true + } + + resp := make([]codersdk.MCPServerConfig, 0, len(configs)) + for _, config := range configs { + var sdkConfig codersdk.MCPServerConfig + if isAdmin { + sdkConfig = convertMCPServerConfig(config) + } else { + sdkConfig = convertMCPServerConfigRedacted(config) + } + if config.AuthType == "oauth2" { + sdkConfig.AuthConnected = tokenMap[config.ID] + } else { + sdkConfig.AuthConnected = true + } + resp = append(resp, sdkConfig) + } + + httpapi.Write(ctx, rw, http.StatusOK, resp) +} + +// @Summary Create MCP server config +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.Forbidden(rw) + return + } + + var req codersdk.CreateMCPServerConfigRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + // Validate auth-type-dependent fields. + switch req.AuthType { + case "oauth2": + if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "OAuth2 auth type requires oauth2_client_id, oauth2_auth_url, and oauth2_token_url.", + }) + return + } + case "api_key": + if req.APIKeyHeader == "" || req.APIKeyValue == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key auth type requires api_key_header and api_key_value.", + }) + return + } + case "custom_headers": + if len(req.CustomHeaders) == 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Custom headers auth type requires at least one custom header.", + }) + return + } + } + + customHeadersJSON, err := marshalCustomHeaders(req.CustomHeaders) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid custom headers.", + Detail: err.Error(), + }) + return + } + + inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{ + DisplayName: strings.TrimSpace(req.DisplayName), + Slug: strings.TrimSpace(req.Slug), + Description: strings.TrimSpace(req.Description), + IconURL: strings.TrimSpace(req.IconURL), + Transport: strings.TrimSpace(req.Transport), + Url: strings.TrimSpace(req.URL), + AuthType: strings.TrimSpace(req.AuthType), + OAuth2ClientID: strings.TrimSpace(req.OAuth2ClientID), + OAuth2ClientSecret: strings.TrimSpace(req.OAuth2ClientSecret), + OAuth2ClientSecretKeyID: sql.NullString{}, + OAuth2AuthURL: strings.TrimSpace(req.OAuth2AuthURL), + OAuth2TokenURL: strings.TrimSpace(req.OAuth2TokenURL), + OAuth2Scopes: strings.TrimSpace(req.OAuth2Scopes), + APIKeyHeader: strings.TrimSpace(req.APIKeyHeader), + APIKeyValue: strings.TrimSpace(req.APIKeyValue), + APIKeyValueKeyID: sql.NullString{}, + CustomHeaders: customHeadersJSON, + CustomHeadersKeyID: sql.NullString{}, + ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), + ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), + Availability: strings.TrimSpace(req.Availability), + Enabled: req.Enabled, + CreatedBy: apiKey.UserID, + UpdatedBy: apiKey.UserID, + }) + if err != nil { + switch { + case database.IsUniqueViolation(err): + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "MCP server config already exists.", + Detail: err.Error(), + }) + return + case database.IsCheckViolation(err): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid MCP server config.", + Detail: err.Error(), + }) + return + default: + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to create MCP server config.", + Detail: err.Error(), + }) + return + } + } + + httpapi.Write(ctx, rw, http.StatusCreated, convertMCPServerConfig(inserted)) +} + +// @Summary Get MCP server config +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) getMCPServerConfig(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + mcpServerID, ok := parseMCPServerConfigID(rw, r) + if !ok { + return + } + + isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) + + var config database.MCPServerConfig + var err error + if isAdmin { + config, err = api.Database.GetMCPServerConfigByID(ctx, mcpServerID) + } else { + //nolint:gocritic // All authenticated users can view enabled MCP server configs. + config, err = api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID) + if err == nil && !config.Enabled { + httpapi.ResourceNotFound(rw) + return + } + } + if err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get MCP server config.", + Detail: err.Error(), + }) + return + } + + var sdkConfig codersdk.MCPServerConfig + if isAdmin { + sdkConfig = convertMCPServerConfig(config) + } else { + sdkConfig = convertMCPServerConfigRedacted(config) + } + + // Populate AuthConnected for the calling user. + if config.AuthType == "oauth2" { + //nolint:gocritic // Need to check user token for this server. + userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get user tokens.", + Detail: err.Error(), + }) + return + } + for _, t := range userTokens { + if t.MCPServerConfigID == config.ID { + sdkConfig.AuthConnected = true + break + } + } + } else { + sdkConfig.AuthConnected = true + } + + httpapi.Write(ctx, rw, http.StatusOK, sdkConfig) +} + +// @Summary Update MCP server config +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.Forbidden(rw) + return + } + + mcpServerID, ok := parseMCPServerConfigID(rw, r) + if !ok { + return + } + + var req codersdk.UpdateMCPServerConfigRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + // Pre-validate custom headers before entering the transaction. + var customHeadersJSON string + if req.CustomHeaders != nil { + var chErr error + customHeadersJSON, chErr = marshalCustomHeaders(*req.CustomHeaders) + if chErr != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid custom headers.", + Detail: chErr.Error(), + }) + return + } + } + + var updated database.MCPServerConfig + err := api.Database.InTx(func(tx database.Store) error { + existing, err := tx.GetMCPServerConfigByID(ctx, mcpServerID) + if err != nil { + return err + } + + displayName := existing.DisplayName + if req.DisplayName != nil { + displayName = strings.TrimSpace(*req.DisplayName) + } + + slug := existing.Slug + if req.Slug != nil { + slug = strings.TrimSpace(*req.Slug) + } + + description := existing.Description + if req.Description != nil { + description = strings.TrimSpace(*req.Description) + } + + iconURL := existing.IconURL + if req.IconURL != nil { + iconURL = strings.TrimSpace(*req.IconURL) + } + + transport := existing.Transport + if req.Transport != nil { + transport = strings.TrimSpace(*req.Transport) + } + + serverURL := existing.Url + if req.URL != nil { + serverURL = strings.TrimSpace(*req.URL) + } + + authType := existing.AuthType + if req.AuthType != nil { + authType = strings.TrimSpace(*req.AuthType) + } + + oauth2ClientID := existing.OAuth2ClientID + if req.OAuth2ClientID != nil { + oauth2ClientID = strings.TrimSpace(*req.OAuth2ClientID) + } + + oauth2ClientSecret := existing.OAuth2ClientSecret + oauth2ClientSecretKeyID := existing.OAuth2ClientSecretKeyID + if req.OAuth2ClientSecret != nil { + oauth2ClientSecret = strings.TrimSpace(*req.OAuth2ClientSecret) + // Clear the key ID when the secret is explicitly updated. + oauth2ClientSecretKeyID = sql.NullString{} + } + + oauth2AuthURL := existing.OAuth2AuthURL + if req.OAuth2AuthURL != nil { + oauth2AuthURL = strings.TrimSpace(*req.OAuth2AuthURL) + } + + oauth2TokenURL := existing.OAuth2TokenURL + if req.OAuth2TokenURL != nil { + oauth2TokenURL = strings.TrimSpace(*req.OAuth2TokenURL) + } + + oauth2Scopes := existing.OAuth2Scopes + if req.OAuth2Scopes != nil { + oauth2Scopes = strings.TrimSpace(*req.OAuth2Scopes) + } + + apiKeyHeader := existing.APIKeyHeader + if req.APIKeyHeader != nil { + apiKeyHeader = strings.TrimSpace(*req.APIKeyHeader) + } + + apiKeyValue := existing.APIKeyValue + apiKeyValueKeyID := existing.APIKeyValueKeyID + if req.APIKeyValue != nil { + apiKeyValue = strings.TrimSpace(*req.APIKeyValue) + // Clear the key ID when the value is explicitly updated. + apiKeyValueKeyID = sql.NullString{} + } + + customHeaders := existing.CustomHeaders + customHeadersKeyID := existing.CustomHeadersKeyID + if req.CustomHeaders != nil { + customHeaders = customHeadersJSON + // Clear the key ID when headers are explicitly updated. + customHeadersKeyID = sql.NullString{} + } + + toolAllowList := existing.ToolAllowList + if req.ToolAllowList != nil { + toolAllowList = coalesceStringSlice(trimStringSlice(*req.ToolAllowList)) + } + + toolDenyList := existing.ToolDenyList + if req.ToolDenyList != nil { + toolDenyList = coalesceStringSlice(trimStringSlice(*req.ToolDenyList)) + } + + availability := existing.Availability + if req.Availability != nil { + availability = strings.TrimSpace(*req.Availability) + } + + enabled := existing.Enabled + if req.Enabled != nil { + enabled = *req.Enabled + } + + // When auth_type changes, clear fields belonging to the + // previous auth type so stale secrets don't persist. + if authType != existing.AuthType { + switch authType { + case "none": + oauth2ClientID = "" + oauth2ClientSecret = "" + oauth2ClientSecretKeyID = sql.NullString{} + oauth2AuthURL = "" + oauth2TokenURL = "" + oauth2Scopes = "" + apiKeyHeader = "" + apiKeyValue = "" + apiKeyValueKeyID = sql.NullString{} + customHeaders = "{}" + customHeadersKeyID = sql.NullString{} + case "oauth2": + apiKeyHeader = "" + apiKeyValue = "" + apiKeyValueKeyID = sql.NullString{} + customHeaders = "{}" + customHeadersKeyID = sql.NullString{} + case "api_key": + oauth2ClientID = "" + oauth2ClientSecret = "" + oauth2ClientSecretKeyID = sql.NullString{} + oauth2AuthURL = "" + oauth2TokenURL = "" + oauth2Scopes = "" + customHeaders = "{}" + customHeadersKeyID = sql.NullString{} + case "custom_headers": + oauth2ClientID = "" + oauth2ClientSecret = "" + oauth2ClientSecretKeyID = sql.NullString{} + oauth2AuthURL = "" + oauth2TokenURL = "" + oauth2Scopes = "" + apiKeyHeader = "" + apiKeyValue = "" + apiKeyValueKeyID = sql.NullString{} + } + } + + updated, err = tx.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{ + DisplayName: displayName, + Slug: slug, + Description: description, + IconURL: iconURL, + Transport: transport, + Url: serverURL, + AuthType: authType, + OAuth2ClientID: oauth2ClientID, + OAuth2ClientSecret: oauth2ClientSecret, + OAuth2ClientSecretKeyID: oauth2ClientSecretKeyID, + OAuth2AuthURL: oauth2AuthURL, + OAuth2TokenURL: oauth2TokenURL, + OAuth2Scopes: oauth2Scopes, + APIKeyHeader: apiKeyHeader, + APIKeyValue: apiKeyValue, + APIKeyValueKeyID: apiKeyValueKeyID, + CustomHeaders: customHeaders, + CustomHeadersKeyID: customHeadersKeyID, + ToolAllowList: toolAllowList, + ToolDenyList: toolDenyList, + Availability: availability, + Enabled: enabled, + UpdatedBy: apiKey.UserID, + ID: existing.ID, + }) + return err + }, nil) + if err != nil { + switch { + case httpapi.Is404Error(err): + httpapi.ResourceNotFound(rw) + return + case database.IsUniqueViolation(err): + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "MCP server config slug already exists.", + Detail: err.Error(), + }) + return + case database.IsCheckViolation(err): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid MCP server config.", + Detail: err.Error(), + }) + return + default: + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update MCP server config.", + Detail: err.Error(), + }) + return + } + } + + httpapi.Write(ctx, rw, http.StatusOK, convertMCPServerConfig(updated)) +} + +// @Summary Delete MCP server config +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +func (api *API) deleteMCPServerConfig(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.Forbidden(rw) + return + } + + mcpServerID, ok := parseMCPServerConfigID(rw, r) + if !ok { + return + } + + if _, err := api.Database.GetMCPServerConfigByID(ctx, mcpServerID); err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get MCP server config.", + Detail: err.Error(), + }) + return + } + + if err := api.Database.DeleteMCPServerConfigByID(ctx, mcpServerID); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to delete MCP server config.", + Detail: err.Error(), + }) + return + } + + rw.WriteHeader(http.StatusNoContent) +} + +// @Summary Initiate MCP server OAuth2 connect +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// Redirects the user to the MCP server's OAuth2 authorization URL. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) mcpServerOAuth2Connect(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + mcpServerID, ok := parseMCPServerConfigID(rw, r) + if !ok { + return + } + + //nolint:gocritic // Any authenticated user can initiate OAuth2 for an enabled MCP server. + config, err := api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID) + if err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get MCP server config.", + Detail: err.Error(), + }) + return + } + + if !config.Enabled { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "MCP server is not enabled.", + }) + return + } + + if config.AuthType != "oauth2" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "MCP server does not use OAuth2 authentication.", + }) + return + } + + if config.OAuth2AuthURL == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "MCP server OAuth2 authorization URL is not configured.", + }) + return + } + + // Build the authorization URL. The frontend opens this in a popup. + // The callback URL is on our server; after the exchange we store + // the token and close the popup. + state := uuid.New().String() + http.SetCookie(rw, api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{ + Name: "mcp_oauth2_state_" + config.ID.String(), + Value: state, + Path: fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/callback", config.ID), + MaxAge: 600, // 10 minutes + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + })) + + oauth2Config := &oauth2.Config{ + ClientID: config.OAuth2ClientID, + ClientSecret: config.OAuth2ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: config.OAuth2AuthURL, + TokenURL: config.OAuth2TokenURL, + }, + RedirectURL: fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), config.ID), + } + var scopes []string + if config.OAuth2Scopes != "" { + scopes = strings.Split(config.OAuth2Scopes, " ") + } + oauth2Config.Scopes = scopes + authURL := oauth2Config.AuthCodeURL(state) + http.Redirect(rw, r, authURL, http.StatusTemporaryRedirect) +} + +// @Summary Handle MCP server OAuth2 callback +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// Exchanges the authorization code for tokens and stores them. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) mcpServerOAuth2Callback(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + mcpServerID, ok := parseMCPServerConfigID(rw, r) + if !ok { + return + } + + //nolint:gocritic // Any authenticated user can complete OAuth2 for an enabled MCP server. + config, err := api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID) + if err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get MCP server config.", + Detail: err.Error(), + }) + return + } + + if !config.Enabled { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "MCP server is not enabled.", + }) + return + } + + if config.AuthType != "oauth2" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "MCP server does not use OAuth2 authentication.", + }) + return + } + + // Check if the OAuth2 provider returned an error (e.g., user + // denied consent). + if oauthError := r.URL.Query().Get("error"); oauthError != "" { + desc := r.URL.Query().Get("error_description") + if desc == "" { + desc = oauthError + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "OAuth2 provider returned an error.", + Detail: desc, + }) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing authorization code.", + }) + return + } + + // Validate the state parameter for CSRF protection. + expectedState := "" + if cookie, err := r.Cookie("mcp_oauth2_state_" + config.ID.String()); err == nil { + expectedState = cookie.Value + } + actualState := r.URL.Query().Get("state") + if expectedState == "" || actualState != expectedState { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid or missing OAuth2 state parameter.", + }) + return + } + // Clear the state cookie. + http.SetCookie(rw, api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{ + Name: "mcp_oauth2_state_" + config.ID.String(), + Value: "", + Path: fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/callback", config.ID), + MaxAge: -1, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + })) + + // Exchange the authorization code for tokens. + oauth2Config := &oauth2.Config{ + ClientID: config.OAuth2ClientID, + ClientSecret: config.OAuth2ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: config.OAuth2AuthURL, + TokenURL: config.OAuth2TokenURL, + }, + RedirectURL: fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), config.ID), + } + var scopes []string + if config.OAuth2Scopes != "" { + scopes = strings.Split(config.OAuth2Scopes, " ") + } + oauth2Config.Scopes = scopes + + // Use the deployment's HTTP client for the token exchange to + // respect proxy settings and avoid using http.DefaultClient. + exchangeCtx := context.WithValue(ctx, oauth2.HTTPClient, api.HTTPClient) + token, err := oauth2Config.Exchange(exchangeCtx, code) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadGateway, codersdk.Response{ + Message: "Failed to exchange authorization code for token.", + Detail: "The OAuth2 token exchange with the upstream provider failed.", + }) + return + } + + // Store the token for the user. + refreshToken := "" + if token.RefreshToken != "" { + refreshToken = token.RefreshToken + } + + var expiry sql.NullTime + if !token.Expiry.IsZero() { + expiry = sql.NullTime{Time: token.Expiry, Valid: true} + } + + //nolint:gocritic // Users store their own tokens. + _, err = api.Database.UpsertMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), database.UpsertMCPServerUserTokenParams{ + MCPServerConfigID: mcpServerID, + UserID: apiKey.UserID, + AccessToken: token.AccessToken, + AccessTokenKeyID: sql.NullString{}, + RefreshToken: refreshToken, + RefreshTokenKeyID: sql.NullString{}, + TokenType: token.TokenType, + Expiry: expiry, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to store OAuth2 token.", + Detail: err.Error(), + }) + return + } + + // Respond with a simple HTML page that closes the popup window. + rw.Header().Set("Content-Security-Policy", "default-src 'none'; script-src 'unsafe-inline'") + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte(``)) +} + +// @Summary Disconnect MCP server OAuth2 token +// @x-apidocgen {"skip": true} +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// Removes the user's stored OAuth2 token for an MCP server. +func (api *API) mcpServerOAuth2Disconnect(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + mcpServerID, ok := parseMCPServerConfigID(rw, r) + if !ok { + return + } + + //nolint:gocritic // Users manage their own tokens. + err := api.Database.DeleteMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), database.DeleteMCPServerUserTokenParams{ + MCPServerConfigID: mcpServerID, + UserID: apiKey.UserID, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to disconnect OAuth2 token.", + Detail: err.Error(), + }) + return + } + + rw.WriteHeader(http.StatusNoContent) +} + +// parseMCPServerConfigID extracts the MCP server config UUID from the +// "mcpServer" path parameter. +func parseMCPServerConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { + mcpServerID, err := uuid.Parse(chi.URLParam(r, "mcpServer")) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid MCP server config ID.", + Detail: err.Error(), + }) + return uuid.Nil, false + } + return mcpServerID, true +} + +// convertMCPServerConfig converts a database MCP server config to the +// SDK type. Secrets are never returned; only has_* booleans are set. +// Admin-only fields (OAuth2 client ID, auth URLs, etc.) are included. +func convertMCPServerConfig(config database.MCPServerConfig) codersdk.MCPServerConfig { + return codersdk.MCPServerConfig{ + ID: config.ID, + DisplayName: config.DisplayName, + Slug: config.Slug, + Description: config.Description, + IconURL: config.IconURL, + + Transport: config.Transport, + URL: config.Url, + + AuthType: config.AuthType, + OAuth2ClientID: config.OAuth2ClientID, + HasOAuth2Secret: config.OAuth2ClientSecret != "", + OAuth2AuthURL: config.OAuth2AuthURL, + OAuth2TokenURL: config.OAuth2TokenURL, + OAuth2Scopes: config.OAuth2Scopes, + + APIKeyHeader: config.APIKeyHeader, + HasAPIKey: config.APIKeyValue != "", + + HasCustomHeaders: len(config.CustomHeaders) > 0 && config.CustomHeaders != "{}", + + ToolAllowList: coalesceStringSlice(config.ToolAllowList), + ToolDenyList: coalesceStringSlice(config.ToolDenyList), + + Availability: config.Availability, + + Enabled: config.Enabled, + CreatedAt: config.CreatedAt, + UpdatedAt: config.UpdatedAt, + } +} + +// convertMCPServerConfigRedacted is the same as convertMCPServerConfig +// but strips admin-only fields (OAuth2 details, API key header) for +// non-admin callers. +func convertMCPServerConfigRedacted(config database.MCPServerConfig) codersdk.MCPServerConfig { + c := convertMCPServerConfig(config) + c.URL = "" + c.Transport = "" + c.OAuth2ClientID = "" + c.OAuth2AuthURL = "" + c.OAuth2TokenURL = "" + c.OAuth2Scopes = "" + c.APIKeyHeader = "" + return c +} + +// marshalCustomHeaders encodes a map of custom headers to JSON for +// database storage. A nil map produces an empty JSON object. +func marshalCustomHeaders(headers map[string]string) (string, error) { + if headers == nil { + return "{}", nil + } + encoded, err := json.Marshal(headers) + if err != nil { + return "", err + } + return string(encoded), nil +} + +// trimStringSlice trims whitespace from each element and drops empty +// strings. +func trimStringSlice(ss []string) []string { + if ss == nil { + return nil + } + out := make([]string, 0, len(ss)) + for _, s := range ss { + if trimmed := strings.TrimSpace(s); trimmed != "" { + out = append(out, trimmed) + } + } + return out +} + +// coalesceStringSlice returns ss if non-nil, otherwise an empty +// non-nil slice. This prevents pq.Array from sending NULL for +// NOT NULL text[] columns. +func coalesceStringSlice(ss []string) []string { + if ss == nil { + return []string{} + } + return ss +} diff --git a/coderd/mcp_test.go b/coderd/mcp_test.go new file mode 100644 index 0000000000..0e61f09d2c --- /dev/null +++ b/coderd/mcp_test.go @@ -0,0 +1,489 @@ +package coderd_test + +import ( + "encoding/json" + "net/http" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// mcpDeploymentValues returns deployment values with the agents +// experiment enabled, which is required by the MCP server config +// endpoints. +func mcpDeploymentValues(t testing.TB) *codersdk.DeploymentValues { + t.Helper() + + values := coderdtest.DeploymentValues(t) + values.Experiments = []string{string(codersdk.ExperimentAgents)} + return values +} + +// newMCPClient creates a test server with the agents experiment +// enabled and returns the admin client. +func newMCPClient(t testing.TB) *codersdk.Client { + t.Helper() + + return coderdtest.New(t, &coderdtest.Options{ + DeploymentValues: mcpDeploymentValues(t), + }) +} + +// createMCPServerConfig is a helper that creates a minimal enabled +// MCP server config with auth_type=none. +func createMCPServerConfig(t testing.TB, client *codersdk.Client, slug string, enabled bool) codersdk.MCPServerConfig { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitLong) + config, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "Test Server " + slug, + Slug: slug, + Description: "A test MCP server.", + IconURL: "https://example.com/icon.png", + Transport: "streamable_http", + URL: "https://mcp.example.com/" + slug, + AuthType: "none", + Availability: "default_on", + Enabled: enabled, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + return config +} + +func TestMCPServerConfigsCRUD(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newMCPClient(t) + _ = coderdtest.CreateFirstUser(t, client) + + // Create a config with all fields populated including OAuth2 + // secrets so we can verify they are not leaked. + created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "My MCP Server", + Slug: "my-mcp-server", + Description: "Integration test server.", + IconURL: "https://example.com/icon.png", + Transport: "streamable_http", + URL: "https://mcp.example.com/v1", + AuthType: "oauth2", + OAuth2ClientID: "client-id-123", + OAuth2ClientSecret: "super-secret-value", + OAuth2AuthURL: "https://auth.example.com/authorize", + OAuth2TokenURL: "https://auth.example.com/token", + OAuth2Scopes: "read write", + Availability: "default_on", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, created.ID) + require.Equal(t, "My MCP Server", created.DisplayName) + require.Equal(t, "my-mcp-server", created.Slug) + require.Equal(t, "Integration test server.", created.Description) + require.Equal(t, "streamable_http", created.Transport) + require.Equal(t, "https://mcp.example.com/v1", created.URL) + require.Equal(t, "oauth2", created.AuthType) + require.Equal(t, "client-id-123", created.OAuth2ClientID) + require.Equal(t, "default_on", created.Availability) + require.True(t, created.Enabled) + + // Verify the secret is indicated but never returned. + require.True(t, created.HasOAuth2Secret) + + // Verify the config appears in the list. + configs, err := client.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, configs, 1) + require.Equal(t, created.ID, configs[0].ID) + require.True(t, configs[0].HasOAuth2Secret) + + // Update display name and availability. + newName := "Renamed Server" + newAvail := "force_on" + updated, err := client.UpdateMCPServerConfig(ctx, created.ID, codersdk.UpdateMCPServerConfigRequest{ + DisplayName: &newName, + Availability: &newAvail, + }) + require.NoError(t, err) + require.Equal(t, "Renamed Server", updated.DisplayName) + require.Equal(t, "force_on", updated.Availability) + // Unchanged fields should remain the same. + require.Equal(t, "my-mcp-server", updated.Slug) + require.Equal(t, "oauth2", updated.AuthType) + + // Verify the update took effect through the list. + configs, err = client.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, configs, 1) + require.Equal(t, "Renamed Server", configs[0].DisplayName) + require.Equal(t, "force_on", configs[0].Availability) + + // Delete it. + err = client.DeleteMCPServerConfig(ctx, created.ID) + require.NoError(t, err) + + // Verify it's gone. + configs, err = client.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Empty(t, configs) +} + +func TestMCPServerConfigsNonAdmin(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newMCPClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient) + memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) + + // Admin creates two configs: one enabled, one disabled. + _ = createMCPServerConfig(t, adminClient, "enabled-server", true) + _ = createMCPServerConfig(t, adminClient, "disabled-server", false) + + // Admin sees both. + adminConfigs, err := adminClient.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, adminConfigs, 2) + + // Regular user sees only the enabled one. + memberConfigs, err := memberClient.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, memberConfigs, 1) + require.Equal(t, "enabled-server", memberConfigs[0].Slug) +} + +// TestMCPServerConfigsSecretsNeverLeaked is a load-bearing test that +// ensures secret fields (OAuth2 client secret, API key value, custom +// headers) are never present in API responses for any caller. If this +// test fails, it means a code change accidentally started exposing +// secrets. See: https://github.com/coder/coder/pull/23227#discussion_r2959461109 +func TestMCPServerConfigsSecretsNeverLeaked(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newMCPClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient) + memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) + + // Create a config with ALL secret fields populated. + created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "Secrets Test", + Slug: "secrets-test", + Transport: "streamable_http", + URL: "https://mcp.example.com/secrets", + AuthType: "oauth2", + OAuth2ClientID: "client-id-secret-test", + OAuth2ClientSecret: "THIS-IS-A-SECRET-VALUE", + OAuth2AuthURL: "https://auth.example.com/authorize", + OAuth2TokenURL: "https://auth.example.com/token", + OAuth2Scopes: "read write", + APIKeyHeader: "X-Api-Key", + APIKeyValue: "THIS-IS-A-SECRET-API-KEY", + CustomHeaders: map[string]string{"X-Custom": "THIS-IS-A-SECRET-HEADER"}, + Availability: "default_on", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + + // The sentinel values we must never see in any JSON response. + secrets := []string{ + "THIS-IS-A-SECRET-VALUE", + "THIS-IS-A-SECRET-API-KEY", + "THIS-IS-A-SECRET-HEADER", + } + + assertNoSecrets := func(t *testing.T, label string, v interface{}) { + t.Helper() + data, err := json.Marshal(v) + require.NoError(t, err) + jsonStr := string(data) + for _, secret := range secrets { + assert.False(t, strings.Contains(jsonStr, secret), + "%s: JSON response contains secret %q", label, secret) + } + } + + // Verify the create response doesn't leak secrets. + assertNoSecrets(t, "admin create response", created) + + // Verify boolean indicators are set correctly. + require.True(t, created.HasOAuth2Secret, "HasOAuth2Secret should be true") + require.True(t, created.HasAPIKey, "HasAPIKey should be true") + require.True(t, created.HasCustomHeaders, "HasCustomHeaders should be true") + + // Admin list endpoint. + adminConfigs, err := adminClient.MCPServerConfigs(ctx) + require.NoError(t, err) + require.NotEmpty(t, adminConfigs) + for _, cfg := range adminConfigs { + assertNoSecrets(t, "admin list", cfg) + } + + // Admin get-by-ID endpoint. + adminSingle, err := adminClient.MCPServerConfigByID(ctx, created.ID) + require.NoError(t, err) + assertNoSecrets(t, "admin get-by-id", adminSingle) + + // Non-admin list endpoint. + memberConfigs, err := memberClient.MCPServerConfigs(ctx) + require.NoError(t, err) + require.NotEmpty(t, memberConfigs) + for _, cfg := range memberConfigs { + assertNoSecrets(t, "member list", cfg) + // Non-admin should also not see admin-only fields. + assert.Empty(t, cfg.OAuth2ClientID, "member should not see OAuth2ClientID") + assert.Empty(t, cfg.OAuth2AuthURL, "member should not see OAuth2AuthURL") + assert.Empty(t, cfg.OAuth2TokenURL, "member should not see OAuth2TokenURL") + assert.Empty(t, cfg.APIKeyHeader, "member should not see APIKeyHeader") + assert.Empty(t, cfg.OAuth2Scopes, "member should not see OAuth2Scopes") + assert.Empty(t, cfg.URL, "member should not see URL") + assert.Empty(t, cfg.Transport, "member should not see Transport") + } + + // Non-admin get-by-ID endpoint. + memberSingle, err := memberClient.MCPServerConfigByID(ctx, created.ID) + require.NoError(t, err) + assertNoSecrets(t, "member get-by-id", memberSingle) + assert.Empty(t, memberSingle.OAuth2ClientID, "member should not see OAuth2ClientID") + assert.Empty(t, memberSingle.OAuth2AuthURL, "member should not see OAuth2AuthURL") + assert.Empty(t, memberSingle.OAuth2TokenURL, "member should not see OAuth2TokenURL") + assert.Empty(t, memberSingle.OAuth2Scopes, "member should not see OAuth2Scopes") + assert.Empty(t, memberSingle.APIKeyHeader, "member should not see APIKeyHeader") + assert.Empty(t, memberSingle.URL, "member should not see URL") + assert.Empty(t, memberSingle.Transport, "member should not see Transport") +} + +func TestMCPServerConfigsAuthConnected(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newMCPClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient) + memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) + + // Create an oauth2 server config (enabled). + created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "OAuth Server", + Slug: "oauth-server", + Transport: "streamable_http", + URL: "https://mcp.example.com/oauth", + AuthType: "oauth2", + OAuth2ClientID: "cid", + OAuth2AuthURL: "https://auth.example.com/authorize", + OAuth2TokenURL: "https://auth.example.com/token", + Availability: "default_on", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + + // Regular user lists configs — auth_connected should be false + // because no token has been stored. + memberConfigs, err := memberClient.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, memberConfigs, 1) + require.Equal(t, created.ID, memberConfigs[0].ID) + require.False(t, memberConfigs[0].AuthConnected) + + // Also create a non-oauth server. It should report + // auth_connected=true because no auth is needed. + _ = createMCPServerConfig(t, adminClient, "no-auth-server", true) + memberConfigs, err = memberClient.MCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, memberConfigs, 2) + for _, cfg := range memberConfigs { + if cfg.AuthType == "none" { + require.True(t, cfg.AuthConnected) + } else { + require.False(t, cfg.AuthConnected) + } + } +} + +func TestMCPServerConfigsAvailability(t *testing.T) { + t.Parallel() + + client := newMCPClient(t) + _ = coderdtest.CreateFirstUser(t, client) + + validValues := []string{"force_on", "default_on", "default_off"} + for _, av := range validValues { + av := av + t.Run(av, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "Server " + av, + Slug: "server-" + av, + Transport: "streamable_http", + URL: "https://mcp.example.com/" + av, + AuthType: "none", + Availability: av, + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + require.Equal(t, av, created.Availability) + }) + } + + t.Run("InvalidAvailability", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + _, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "Bad Availability", + Slug: "bad-avail", + Transport: "streamable_http", + URL: "https://mcp.example.com/bad", + AuthType: "none", + Availability: "always_on", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) +} + +func TestMCPServerConfigsUniqueSlug(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newMCPClient(t) + _ = coderdtest.CreateFirstUser(t, client) + + _, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "First", + Slug: "test-server", + Transport: "streamable_http", + URL: "https://mcp.example.com/first", + AuthType: "none", + Availability: "default_off", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + + // Attempt to create another config with the same slug. + _, err = client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "Second", + Slug: "test-server", + Transport: "streamable_http", + URL: "https://mcp.example.com/second", + AuthType: "none", + Availability: "default_off", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusConflict, sdkErr.StatusCode()) +} + +func TestMCPServerConfigsOAuth2Disconnect(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newMCPClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient) + memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) + + created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{ + DisplayName: "OAuth Disconnect Test", + Slug: "oauth-disconnect", + Transport: "streamable_http", + URL: "https://mcp.example.com/oauth-disc", + AuthType: "oauth2", + OAuth2ClientID: "cid", + OAuth2AuthURL: "https://auth.example.com/authorize", + OAuth2TokenURL: "https://auth.example.com/token", + Availability: "default_on", + Enabled: true, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + }) + require.NoError(t, err) + + // Disconnect should succeed even when no token exists (idempotent). + err = memberClient.MCPServerOAuth2Disconnect(ctx, created.ID) + require.NoError(t, err) +} + +func TestChatWithMCPServerIDs(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newMCPClient(t) + _ = coderdtest.CreateFirstUser(t, client) + + // Create the chat model config required for creating a chat. + _ = createChatModelConfigForMCP(t, client) + + // Create an enabled MCP server config. + mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true) + + // Create a chat referencing the MCP server. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "hello with mcp server", + }, + }, + MCPServerIDs: []uuid.UUID{mcpConfig.ID}, + }) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, chat.ID) + require.Contains(t, chat.MCPServerIDs, mcpConfig.ID) + + // Fetch the chat and verify the MCP server IDs persist. + fetched, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID) +} + +// createChatModelConfigForMCP sets up a chat provider and model +// config so that CreateChat succeeds. This mirrors the helper in +// chats_test.go but is defined here to avoid coupling. +func createChatModelConfigForMCP(t testing.TB, client *codersdk.Client) codersdk.ChatModelConfig { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitLong) + _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + contextLimit := int64(4096) + isDefault := true + modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai", + Model: "gpt-4o-mini", + ContextLimit: &contextLimit, + IsDefault: &isDefault, + }) + require.NoError(t, err) + return modelConfig +} diff --git a/coderd/workspaceproxies.go b/coderd/workspaceproxies.go index b8572cafc7..46f3fb0212 100644 --- a/coderd/workspaceproxies.go +++ b/coderd/workspaceproxies.go @@ -41,7 +41,7 @@ func (api *API) PrimaryRegion(ctx context.Context) (codersdk.Region, error) { ID: deploymentID, Name: "primary", DisplayName: proxy.DisplayName, - IconURL: proxy.IconUrl, + IconURL: proxy.IconURL, Healthy: true, PathAppURL: api.AccessURL.String(), WildcardHostname: appurl.SubdomainAppHost(api.AppHostname, api.AccessURL), diff --git a/codersdk/chats.go b/codersdk/chats.go index bdc6ecfbd9..dea346682f 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -49,6 +49,7 @@ type Chat struct { CreatedAt time.Time `json:"created_at" format:"date-time"` UpdatedAt time.Time `json:"updated_at" format:"date-time"` Archived bool `json:"archived"` + MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"` } // ChatMessage represents a single message in a chat. @@ -267,6 +268,7 @@ type CreateChatRequest struct { Content []ChatInputPart `json:"content"` WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` + MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"` } // UpdateChatRequest is the request to update a chat. @@ -279,6 +281,7 @@ type UpdateChatRequest struct { type CreateChatMessageRequest struct { Content []ChatInputPart `json:"content"` ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` + MCPServerIDs *[]uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"` } // EditChatMessageRequest is the request to edit a user message in a chat. diff --git a/codersdk/mcp.go b/codersdk/mcp.go new file mode 100644 index 0000000000..081eb19a77 --- /dev/null +++ b/codersdk/mcp.go @@ -0,0 +1,191 @@ +package codersdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" +) + +// MCPServerOAuth2ConnectURL returns the URL the user should visit to +// start the OAuth2 flow for an MCP server. The frontend opens this +// in a new window/popup. +func (c *Client) MCPServerOAuth2ConnectURL(id uuid.UUID) string { + return fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/connect", c.URL.String(), id) +} + +// MCPServerOAuth2Disconnect removes the user's OAuth2 token for an +// MCP server. +func (c *Client) MCPServerOAuth2Disconnect(ctx context.Context, id uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/disconnect", id), nil) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + +// MCPServerConfig represents an admin-configured MCP server. +type MCPServerConfig struct { + ID uuid.UUID `json:"id" format:"uuid"` + DisplayName string `json:"display_name"` + Slug string `json:"slug"` + Description string `json:"description"` + IconURL string `json:"icon_url"` + + Transport string `json:"transport"` // "streamable_http" or "sse" + URL string `json:"url"` + + AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers" + + // OAuth2 fields (only populated for admins). + OAuth2ClientID string `json:"oauth2_client_id,omitempty"` + HasOAuth2Secret bool `json:"has_oauth2_secret"` + OAuth2AuthURL string `json:"oauth2_auth_url,omitempty"` + OAuth2TokenURL string `json:"oauth2_token_url,omitempty"` + OAuth2Scopes string `json:"oauth2_scopes,omitempty"` + + // API key fields (only populated for admins). + APIKeyHeader string `json:"api_key_header,omitempty"` + HasAPIKey bool `json:"has_api_key"` + + HasCustomHeaders bool `json:"has_custom_headers"` + + // Tool governance. + ToolAllowList []string `json:"tool_allow_list"` + ToolDenyList []string `json:"tool_deny_list"` + + // Availability policy set by admin. + Availability string `json:"availability"` // "force_on", "default_on", "default_off" + + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` + + // Per-user state (populated for non-admin requests). + AuthConnected bool `json:"auth_connected"` +} + +// CreateMCPServerConfigRequest is the request to create a new MCP server config. +type CreateMCPServerConfigRequest struct { + DisplayName string `json:"display_name" validate:"required"` + Slug string `json:"slug" validate:"required"` + Description string `json:"description"` + IconURL string `json:"icon_url"` + + Transport string `json:"transport" validate:"required,oneof=streamable_http sse"` + URL string `json:"url" validate:"required,url"` + + AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers"` + OAuth2ClientID string `json:"oauth2_client_id,omitempty"` + OAuth2ClientSecret string `json:"oauth2_client_secret,omitempty"` + OAuth2AuthURL string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"` + OAuth2TokenURL string `json:"oauth2_token_url,omitempty" validate:"omitempty,url"` + OAuth2Scopes string `json:"oauth2_scopes,omitempty"` + APIKeyHeader string `json:"api_key_header,omitempty"` + APIKeyValue string `json:"api_key_value,omitempty"` + CustomHeaders map[string]string `json:"custom_headers,omitempty"` + + ToolAllowList []string `json:"tool_allow_list,omitempty"` + ToolDenyList []string `json:"tool_deny_list,omitempty"` + + Availability string `json:"availability" validate:"required,oneof=force_on default_on default_off"` + Enabled bool `json:"enabled"` +} + +// UpdateMCPServerConfigRequest is the request to update an MCP server config. +type UpdateMCPServerConfigRequest struct { + DisplayName *string `json:"display_name,omitempty"` + Slug *string `json:"slug,omitempty"` + Description *string `json:"description,omitempty"` + IconURL *string `json:"icon_url,omitempty"` + + Transport *string `json:"transport,omitempty" validate:"omitempty,oneof=streamable_http sse"` + URL *string `json:"url,omitempty" validate:"omitempty,url"` + + AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers"` + OAuth2ClientID *string `json:"oauth2_client_id,omitempty"` + OAuth2ClientSecret *string `json:"oauth2_client_secret,omitempty"` + OAuth2AuthURL *string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"` + OAuth2TokenURL *string `json:"oauth2_token_url,omitempty" validate:"omitempty,url"` + OAuth2Scopes *string `json:"oauth2_scopes,omitempty"` + APIKeyHeader *string `json:"api_key_header,omitempty"` + APIKeyValue *string `json:"api_key_value,omitempty"` + CustomHeaders *map[string]string `json:"custom_headers,omitempty"` + + ToolAllowList *[]string `json:"tool_allow_list,omitempty"` + ToolDenyList *[]string `json:"tool_deny_list,omitempty"` + + Availability *string `json:"availability,omitempty" validate:"omitempty,oneof=force_on default_on default_off"` + Enabled *bool `json:"enabled,omitempty"` +} + +func (c *Client) MCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/mcp/servers", nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + var configs []MCPServerConfig + return configs, json.NewDecoder(res.Body).Decode(&configs) +} + +func (c *Client) MCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), nil) + if err != nil { + return MCPServerConfig{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return MCPServerConfig{}, ReadBodyAsError(res) + } + var config MCPServerConfig + return config, json.NewDecoder(res.Body).Decode(&config) +} + +func (c *Client) CreateMCPServerConfig(ctx context.Context, req CreateMCPServerConfigRequest) (MCPServerConfig, error) { + res, err := c.Request(ctx, http.MethodPost, "/api/experimental/mcp/servers", req) + if err != nil { + return MCPServerConfig{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return MCPServerConfig{}, ReadBodyAsError(res) + } + var config MCPServerConfig + return config, json.NewDecoder(res.Body).Decode(&config) +} + +func (c *Client) UpdateMCPServerConfig(ctx context.Context, id uuid.UUID, req UpdateMCPServerConfigRequest) (MCPServerConfig, error) { + res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), req) + if err != nil { + return MCPServerConfig{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return MCPServerConfig{}, ReadBodyAsError(res) + } + var config MCPServerConfig + return config, json.NewDecoder(res.Body).Decode(&config) +} + +func (c *Client) DeleteMCPServerConfig(ctx context.Context, id uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), nil) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} diff --git a/enterprise/coderd/workspaceproxy.go b/enterprise/coderd/workspaceproxy.go index 2832707dc8..43486da8f5 100644 --- a/enterprise/coderd/workspaceproxy.go +++ b/enterprise/coderd/workspaceproxy.go @@ -204,7 +204,7 @@ func (api *API) patchPrimaryWorkspaceProxy(req codersdk.PatchWorkspaceProxy, rw args := database.UpsertDefaultProxyParams{ DisplayName: req.DisplayName, - IconUrl: req.Icon, + IconURL: req.Icon, } if req.DisplayName == "" || req.Icon == "" { // If the user has not specified an update value, use the existing value. @@ -217,7 +217,7 @@ func (api *API) patchPrimaryWorkspaceProxy(req codersdk.PatchWorkspaceProxy, rw args.DisplayName = existing.DisplayName } if req.Icon == "" { - args.IconUrl = existing.IconUrl + args.IconURL = existing.IconURL } } diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 0dcf8c928a..3c5f957786 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -471,6 +471,201 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat return provider, nil } +// decryptMCPServerConfig decrypts all encrypted fields on a +// single MCPServerConfig in place. +func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error { + if err := db.decryptField(&cfg.OAuth2ClientSecret, cfg.OAuth2ClientSecretKeyID); err != nil { + return err + } + if err := db.decryptField(&cfg.APIKeyValue, cfg.APIKeyValueKeyID); err != nil { + return err + } + return db.decryptField(&cfg.CustomHeaders, cfg.CustomHeadersKeyID) +} + +// decryptMCPServerUserToken decrypts all encrypted fields on a +// single MCPServerUserToken in place. +func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) error { + if err := db.decryptField(&tok.AccessToken, tok.AccessTokenKeyID); err != nil { + return err + } + return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID) +} + +func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) { + cfg, err := db.Store.GetMCPServerConfigByID(ctx, id) + if err != nil { + return database.MCPServerConfig{}, err + } + if err := db.decryptMCPServerConfig(&cfg); err != nil { + return database.MCPServerConfig{}, err + } + return cfg, nil +} + +func (db *dbCrypt) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) { + cfg, err := db.Store.GetMCPServerConfigBySlug(ctx, slug) + if err != nil { + return database.MCPServerConfig{}, err + } + if err := db.decryptMCPServerConfig(&cfg); err != nil { + return database.MCPServerConfig{}, err + } + return cfg, nil +} + +func (db *dbCrypt) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + cfgs, err := db.Store.GetMCPServerConfigs(ctx) + if err != nil { + return nil, err + } + for i := range cfgs { + if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil { + return nil, err + } + } + return cfgs, nil +} + +func (db *dbCrypt) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) { + cfgs, err := db.Store.GetMCPServerConfigsByIDs(ctx, ids) + if err != nil { + return nil, err + } + for i := range cfgs { + if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil { + return nil, err + } + } + return cfgs, nil +} + +func (db *dbCrypt) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + cfgs, err := db.Store.GetEnabledMCPServerConfigs(ctx) + if err != nil { + return nil, err + } + for i := range cfgs { + if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil { + return nil, err + } + } + return cfgs, nil +} + +func (db *dbCrypt) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { + cfgs, err := db.Store.GetForcedMCPServerConfigs(ctx) + if err != nil { + return nil, err + } + for i := range cfgs { + if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil { + return nil, err + } + } + return cfgs, nil +} + +func (db *dbCrypt) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + tok, err := db.Store.GetMCPServerUserToken(ctx, arg) + if err != nil { + return database.MCPServerUserToken{}, err + } + if err := db.decryptMCPServerUserToken(&tok); err != nil { + return database.MCPServerUserToken{}, err + } + return tok, nil +} + +func (db *dbCrypt) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) { + toks, err := db.Store.GetMCPServerUserTokensByUserID(ctx, userID) + if err != nil { + return nil, err + } + for i := range toks { + if err := db.decryptMCPServerUserToken(&toks[i]); err != nil { + return nil, err + } + } + return toks, nil +} + +func (db *dbCrypt) InsertMCPServerConfig(ctx context.Context, params database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) { + if strings.TrimSpace(params.OAuth2ClientSecret) == "" { + params.OAuth2ClientSecretKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil { + return database.MCPServerConfig{}, err + } + if strings.TrimSpace(params.APIKeyValue) == "" { + params.APIKeyValueKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil { + return database.MCPServerConfig{}, err + } + if strings.TrimSpace(params.CustomHeaders) == "" { + params.CustomHeadersKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil { + return database.MCPServerConfig{}, err + } + + cfg, err := db.Store.InsertMCPServerConfig(ctx, params) + if err != nil { + return database.MCPServerConfig{}, err + } + if err := db.decryptMCPServerConfig(&cfg); err != nil { + return database.MCPServerConfig{}, err + } + return cfg, nil +} + +func (db *dbCrypt) UpdateMCPServerConfig(ctx context.Context, params database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) { + if strings.TrimSpace(params.OAuth2ClientSecret) == "" { + params.OAuth2ClientSecretKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil { + return database.MCPServerConfig{}, err + } + if strings.TrimSpace(params.APIKeyValue) == "" { + params.APIKeyValueKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil { + return database.MCPServerConfig{}, err + } + if strings.TrimSpace(params.CustomHeaders) == "" { + params.CustomHeadersKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil { + return database.MCPServerConfig{}, err + } + + cfg, err := db.Store.UpdateMCPServerConfig(ctx, params) + if err != nil { + return database.MCPServerConfig{}, err + } + if err := db.decryptMCPServerConfig(&cfg); err != nil { + return database.MCPServerConfig{}, err + } + return cfg, nil +} + +func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { + if strings.TrimSpace(params.AccessToken) == "" { + params.AccessTokenKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.AccessToken, ¶ms.AccessTokenKeyID); err != nil { + return database.MCPServerUserToken{}, err + } + if strings.TrimSpace(params.RefreshToken) == "" { + params.RefreshTokenKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.RefreshToken, ¶ms.RefreshTokenKeyID); err != nil { + return database.MCPServerUserToken{}, err + } + + tok, err := db.Store.UpsertMCPServerUserToken(ctx, params) + if err != nil { + return database.MCPServerUserToken{}, err + } + if err := db.decryptMCPServerUserToken(&tok); err != nil { + return database.MCPServerUserToken{}, err + } + return tok, nil +} + func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error { // If no cipher is loaded, then we can't encrypt anything! if db.ciphers == nil || db.primaryCipherDigest == "" { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index fcf9eae2de..d664987a56 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/lib/pq" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -878,3 +879,301 @@ func fakeBase64RandomData(t *testing.T, n int) string { require.NoError(t, err) return base64.StdEncoding.EncodeToString(b) } + +// requireMCPServerConfigDecrypted verifies all encrypted fields on an +// MCPServerConfig match the expected plaintext values and carry the +// correct key-ID. +func requireMCPServerConfigDecrypted( + t *testing.T, + cfg database.MCPServerConfig, + ciphers []Cipher, + wantSecret, wantAPIKey, wantHeaders string, +) { + t.Helper() + require.Equal(t, wantSecret, cfg.OAuth2ClientSecret) + require.Equal(t, wantAPIKey, cfg.APIKeyValue) + require.Equal(t, wantHeaders, cfg.CustomHeaders) + require.Equal(t, ciphers[0].HexDigest(), cfg.OAuth2ClientSecretKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), cfg.APIKeyValueKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), cfg.CustomHeadersKeyID.String) +} + +// requireMCPServerConfigRawEncrypted reads the config from the raw +// (unwrapped) store and asserts every secret field is encrypted. +func requireMCPServerConfigRawEncrypted( + ctx context.Context, + t *testing.T, + rawDB database.Store, + cfgID uuid.UUID, + ciphers []Cipher, + wantSecret, wantAPIKey, wantHeaders string, +) { + t.Helper() + raw, err := rawDB.GetMCPServerConfigByID(ctx, cfgID) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], raw.OAuth2ClientSecret, wantSecret) + requireEncryptedEquals(t, ciphers[0], raw.APIKeyValue, wantAPIKey) + requireEncryptedEquals(t, ciphers[0], raw.CustomHeaders, wantHeaders) +} + +func TestMCPServerConfigs(t *testing.T) { + t.Parallel() + ctx := context.Background() + + const ( + //nolint:gosec // test credentials + oauthSecret = "my-oauth-secret" + apiKeyValue = "my-api-key" + customHeaders = `{"X-Custom":"header-value"}` + ) + // insertConfig is a small helper that creates a user and an MCP + // server config through the encrypted store, returning both. + insertConfig := func(t *testing.T, crypt *dbCrypt, ciphers []Cipher) database.MCPServerConfig { + t.Helper() + user := dbgen.User(t, crypt, database.User{}) + cfg, err := crypt.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{ + DisplayName: "Test MCP Server", + Slug: "test-mcp-" + uuid.New().String()[:8], + Description: "test description", + Url: "https://mcp.example.com", + Transport: "streamable_http", + AuthType: "oauth2", + OAuth2ClientID: "client-id", + OAuth2ClientSecret: oauthSecret, + APIKeyValue: apiKeyValue, + CustomHeaders: customHeaders, + ToolAllowList: []string{}, + ToolDenyList: []string{}, + Availability: "force_on", + Enabled: true, + CreatedBy: user.ID, + UpdatedBy: user.ID, + }) + require.NoError(t, err) + requireMCPServerConfigDecrypted(t, cfg, ciphers, oauthSecret, apiKeyValue, customHeaders) + return cfg + } + + t.Run("InsertMCPServerConfig", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("GetMCPServerConfigByID", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + got, err := crypt.GetMCPServerConfigByID(ctx, cfg.ID) + require.NoError(t, err) + requireMCPServerConfigDecrypted(t, got, ciphers, oauthSecret, apiKeyValue, customHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("GetMCPServerConfigBySlug", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + got, err := crypt.GetMCPServerConfigBySlug(ctx, cfg.Slug) + require.NoError(t, err) + requireMCPServerConfigDecrypted(t, got, ciphers, oauthSecret, apiKeyValue, customHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("GetMCPServerConfigs", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + cfgs, err := crypt.GetMCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, cfgs, 1) + requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("GetMCPServerConfigsByIDs", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + cfgs, err := crypt.GetMCPServerConfigsByIDs(ctx, []uuid.UUID{cfg.ID}) + require.NoError(t, err) + require.Len(t, cfgs, 1) + requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("GetEnabledMCPServerConfigs", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + cfgs, err := crypt.GetEnabledMCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, cfgs, 1) + requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("GetForcedMCPServerConfigs", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + cfgs, err := crypt.GetForcedMCPServerConfigs(ctx) + require.NoError(t, err) + require.Len(t, cfgs, 1) + requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders) + }) + + t.Run("UpdateMCPServerConfig", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg := insertConfig(t, crypt, ciphers) + + const ( + //nolint:gosec // test credential + newSecret = "updated-oauth-secret" + newAPIKey = "updated-api-key" + newHeaders = `{"X-New":"new-value"}` + ) + updated, err := crypt.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{ + ID: cfg.ID, + DisplayName: cfg.DisplayName, + Slug: cfg.Slug, + Description: cfg.Description, + Url: cfg.Url, + Transport: cfg.Transport, + AuthType: cfg.AuthType, + OAuth2ClientID: cfg.OAuth2ClientID, + OAuth2ClientSecret: newSecret, + APIKeyValue: newAPIKey, + CustomHeaders: newHeaders, + ToolAllowList: cfg.ToolAllowList, + ToolDenyList: cfg.ToolDenyList, + Availability: cfg.Availability, + Enabled: cfg.Enabled, + UpdatedBy: cfg.CreatedBy.UUID, + }) + require.NoError(t, err) + requireMCPServerConfigDecrypted(t, updated, ciphers, newSecret, newAPIKey, newHeaders) + requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, newSecret, newAPIKey, newHeaders) + }) +} + +func TestMCPServerUserTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + + const ( + accessToken = "access-token-value" + refreshToken = "refresh-token-value" + ) + + // insertConfigAndToken creates a user, an MCP server config, and a + // user token through the encrypted store. + insertConfigAndToken := func( + t *testing.T, + crypt *dbCrypt, + ciphers []Cipher, + ) (database.MCPServerConfig, database.MCPServerUserToken) { + t.Helper() + user := dbgen.User(t, crypt, database.User{}) + cfg, err := crypt.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{ + DisplayName: "Token Test MCP", + Slug: "tok-mcp-" + uuid.New().String()[:8], + Url: "https://mcp.example.com", + Transport: "streamable_http", + AuthType: "oauth2", + ToolAllowList: []string{}, + ToolDenyList: []string{}, + Availability: "default_off", + Enabled: true, + CreatedBy: user.ID, + UpdatedBy: user.ID, + }) + require.NoError(t, err) + + tok, err := crypt.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{ + MCPServerConfigID: cfg.ID, + UserID: user.ID, + AccessToken: accessToken, + RefreshToken: refreshToken, + TokenType: "Bearer", + }) + require.NoError(t, err) + require.Equal(t, accessToken, tok.AccessToken) + require.Equal(t, refreshToken, tok.RefreshToken) + require.Equal(t, ciphers[0].HexDigest(), tok.AccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), tok.RefreshTokenKeyID.String) + return cfg, tok + } + + t.Run("UpsertMCPServerUserToken", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, tok := insertConfigAndToken(t, crypt, ciphers) + + // Verify the raw DB values are encrypted. + rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{ + MCPServerConfigID: cfg.ID, + UserID: tok.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken) + requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken) + }) + + t.Run("GetMCPServerUserToken", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, tok := insertConfigAndToken(t, crypt, ciphers) + + got, err := crypt.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{ + MCPServerConfigID: cfg.ID, + UserID: tok.UserID, + }) + require.NoError(t, err) + require.Equal(t, accessToken, got.AccessToken) + require.Equal(t, refreshToken, got.RefreshToken) + require.Equal(t, ciphers[0].HexDigest(), got.AccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), got.RefreshTokenKeyID.String) + + // Raw values must be encrypted. + rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{ + MCPServerConfigID: cfg.ID, + UserID: tok.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken) + requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken) + }) + + t.Run("GetMCPServerUserTokensByUserID", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, tok := insertConfigAndToken(t, crypt, ciphers) + + toks, err := crypt.GetMCPServerUserTokensByUserID(ctx, tok.UserID) + require.NoError(t, err) + require.Len(t, toks, 1) + require.Equal(t, accessToken, toks[0].AccessToken) + require.Equal(t, refreshToken, toks[0].RefreshToken) + require.Equal(t, ciphers[0].HexDigest(), toks[0].AccessTokenKeyID.String) + require.Equal(t, ciphers[0].HexDigest(), toks[0].RefreshTokenKeyID.String) + + // Raw values must be encrypted. + rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{ + MCPServerConfigID: cfg.ID, + UserID: tok.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken) + requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken) + }) +} diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index d51a2909d2..8b4ac90f76 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -73,6 +73,7 @@ const makeChat = ( id, owner_id: "owner-1", last_model_config_id: "model-1", + mcp_server_ids: [], title: `Chat ${id}`, status: "running", created_at: "2025-01-01T00:00:00.000Z", diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 584da94d1a..d988e464eb 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1070,6 +1070,7 @@ export interface Chat { readonly created_at: string; readonly updated_at: string; readonly archived: boolean; + readonly mcp_server_ids: readonly string[]; } // From codersdk/deployment.go @@ -2076,6 +2077,7 @@ export interface ConvertLoginRequest { export interface CreateChatMessageRequest { readonly content: readonly ChatInputPart[]; readonly model_config_id?: string; + readonly mcp_server_ids?: string[]; } // From codersdk/chats.go @@ -2123,6 +2125,7 @@ export interface CreateChatRequest { readonly content: readonly ChatInputPart[]; readonly workspace_id?: string; readonly model_config_id?: string; + readonly mcp_server_ids?: readonly string[]; } // From codersdk/users.go @@ -2163,6 +2166,32 @@ export interface CreateGroupRequest { readonly quota_allowance: number; } +// From codersdk/mcp.go +/** + * CreateMCPServerConfigRequest is the request to create a new MCP server config. + */ +export interface CreateMCPServerConfigRequest { + readonly display_name: string; + readonly slug: string; + readonly description: string; + readonly icon_url: string; + readonly transport: string; + readonly url: string; + readonly auth_type: string; + readonly oauth2_client_id?: string; + readonly oauth2_client_secret?: string; + readonly oauth2_auth_url?: string; + readonly oauth2_token_url?: string; + readonly oauth2_scopes?: string; + readonly api_key_header?: string; + readonly api_key_value?: string; + readonly custom_headers?: Record; + readonly tool_allow_list?: readonly string[]; + readonly tool_deny_list?: readonly string[]; + readonly availability: string; + readonly enabled: boolean; +} + // From codersdk/organizations.go export interface CreateOrganizationRequest { readonly name: string; @@ -3692,6 +3721,51 @@ export interface LoginWithPasswordResponse { readonly session_token: string; } +// From codersdk/mcp.go +/** + * MCPServerConfig represents an admin-configured MCP server. + */ +export interface MCPServerConfig { + readonly id: string; + readonly display_name: string; + readonly slug: string; + readonly description: string; + readonly icon_url: string; + readonly transport: string; // "streamable_http" or "sse" + readonly url: string; + readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers" + /** + * OAuth2 fields (only populated for admins). + */ + readonly oauth2_client_id?: string; + readonly has_oauth2_secret: boolean; + readonly oauth2_auth_url?: string; + readonly oauth2_token_url?: string; + readonly oauth2_scopes?: string; + /** + * API key fields (only populated for admins). + */ + readonly api_key_header?: string; + readonly has_api_key: boolean; + readonly has_custom_headers: boolean; + /** + * Tool governance. + */ + readonly tool_allow_list: readonly string[]; + readonly tool_deny_list: readonly string[]; + /** + * Availability policy set by admin. + */ + readonly availability: string; // "force_on", "default_on", "default_off" + readonly enabled: boolean; + readonly created_at: string; + readonly updated_at: string; + /** + * Per-user state (populated for non-admin requests). + */ + readonly auth_connected: boolean; +} + // From codersdk/provisionerdaemons.go /** * MatchedProvisioners represents the number of provisioner daemons @@ -6844,6 +6918,32 @@ export interface UpdateInboxNotificationReadStatusResponse { readonly unread_count: number; } +// From codersdk/mcp.go +/** + * UpdateMCPServerConfigRequest is the request to update an MCP server config. + */ +export interface UpdateMCPServerConfigRequest { + readonly display_name?: string; + readonly slug?: string; + readonly description?: string; + readonly icon_url?: string; + readonly transport?: string; + readonly url?: string; + readonly auth_type?: string; + readonly oauth2_client_id?: string; + readonly oauth2_client_secret?: string; + readonly oauth2_auth_url?: string; + readonly oauth2_token_url?: string; + readonly oauth2_scopes?: string; + readonly api_key_header?: string; + readonly api_key_value?: string; + readonly custom_headers?: Record; + readonly tool_allow_list?: string[]; + readonly tool_deny_list?: string[]; + readonly availability?: string; + readonly enabled?: boolean; +} + // From codersdk/notifications.go export interface UpdateNotificationTemplateMethod { readonly method?: string; diff --git a/site/src/pages/AgentsPage/AgentDetail.stories.tsx b/site/src/pages/AgentsPage/AgentDetail.stories.tsx index 2a556eb0b1..3b4118f881 100644 --- a/site/src/pages/AgentsPage/AgentDetail.stories.tsx +++ b/site/src/pages/AgentsPage/AgentDetail.stories.tsx @@ -106,6 +106,7 @@ const baseChatFields = { owner_id: "owner-id", workspace_id: mockWorkspace.id, last_model_config_id: "model-config-1", + mcp_server_ids: [], created_at: "2026-02-18T00:00:00.000Z", updated_at: "2026-02-18T00:00:00.000Z", archived: false, diff --git a/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx b/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx index 06c6705521..c429ff5e8e 100644 --- a/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx +++ b/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx @@ -183,6 +183,7 @@ const makeChat = (chatID: string): TypesGen.Chat => ({ id: chatID, owner_id: "owner-1", last_model_config_id: "model-1", + mcp_server_ids: [], title: "test", status: "running", created_at: "2025-01-01T00:00:00.000Z", diff --git a/site/src/pages/AgentsPage/AgentDetail/TopBar.stories.tsx b/site/src/pages/AgentsPage/AgentDetail/TopBar.stories.tsx index e794833fa5..2d0aa50819 100644 --- a/site/src/pages/AgentsPage/AgentDetail/TopBar.stories.tsx +++ b/site/src/pages/AgentsPage/AgentDetail/TopBar.stories.tsx @@ -52,6 +52,7 @@ export const WithParentChat: Story = { id: "parent-chat-1", owner_id: "owner-id", last_model_config_id: "model-config-1", + mcp_server_ids: [], title: "Set up CI/CD pipeline", status: "completed", last_error: null, diff --git a/site/src/pages/AgentsPage/AgentDetailView.stories.tsx b/site/src/pages/AgentsPage/AgentDetailView.stories.tsx index 19dce7a02e..db32338c48 100644 --- a/site/src/pages/AgentsPage/AgentDetailView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentDetailView.stories.tsx @@ -36,6 +36,7 @@ const buildChat = (overrides: Partial = {}): TypesGen.Chat => ({ title: "Help me refactor", status: "completed", last_model_config_id: "model-config-1", + mcp_server_ids: [], created_at: oneWeekAgo, updated_at: oneWeekAgo, archived: false, diff --git a/site/src/pages/AgentsPage/AgentsPageView.stories.tsx b/site/src/pages/AgentsPage/AgentsPageView.stories.tsx index 891101d675..d87e0a5116 100644 --- a/site/src/pages/AgentsPage/AgentsPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsPageView.stories.tsx @@ -113,6 +113,7 @@ const buildChat = (overrides: Partial = {}): Chat => ({ title: "Agent", status: "completed", last_model_config_id: defaultModelConfigs[0].id, + mcp_server_ids: [], created_at: oneWeekAgo, updated_at: oneWeekAgo, archived: false, diff --git a/site/src/pages/AgentsPage/AgentsSidebar.stories.tsx b/site/src/pages/AgentsPage/AgentsSidebar.stories.tsx index ae59fdc9ac..c2c4e8fdb9 100644 --- a/site/src/pages/AgentsPage/AgentsSidebar.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsSidebar.stories.tsx @@ -40,6 +40,7 @@ const buildChat = (overrides: Partial = {}): Chat => ({ title: "Agent", status: "completed", last_model_config_id: defaultModelConfigs[0].id, + mcp_server_ids: [], created_at: oneWeekAgo, updated_at: oneWeekAgo, archived: false, diff --git a/site/src/pages/AgentsPage/AgentsSidebar.test.tsx b/site/src/pages/AgentsPage/AgentsSidebar.test.tsx index 37fe353c50..92f21f0136 100644 --- a/site/src/pages/AgentsPage/AgentsSidebar.test.tsx +++ b/site/src/pages/AgentsPage/AgentsSidebar.test.tsx @@ -63,6 +63,7 @@ const buildChat = (overrides: Partial = {}): Chat => ({ updated_at: oneWeekAgo, archived: false, last_error: null, + mcp_server_ids: [], ...overrides, });