mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add MCP server configuration backend for chats (#23227)
## Summary
Adds the database schema, API endpoints, SDK types, and encryption
wrappers for admin-managed MCP (Model Context Protocol) server
configurations that chatd can consume. This is the backend foundation
for allowing external MCP tools (Sentry, Linear, GitHub, etc.) to be
used during AI chat sessions.
## Database
Two new tables:
- **`mcp_server_configs`**: Admin-managed server definitions with URL,
transport (Streamable HTTP / SSE), auth config (none / OAuth2 / API key
/ custom headers), tool allow/deny lists, and an availability policy
(`force_on` / `default_on` / `default_off`). Includes CHECK constraints
on transport, auth_type, and availability values.
- **`mcp_server_user_tokens`**: Per-user OAuth2 tokens for servers
requiring individual authentication. Cascades on user/config deletion.
New column on `chats` table:
- **`mcp_server_ids UUID[]`**: Per-chat MCP server selection, following
the same pattern as `model_config_id` — passed at chat creation,
changeable per-message with nil-means-no-change semantics.
## API Endpoints
All routes are under `/api/experimental/mcp/servers/` and gated behind
the `agents` experiment.
**Admin endpoints** (`ResourceDeploymentConfig` auth):
- `POST /` — Create MCP server config
- `PATCH /{id}` — Update MCP server config (full-replace)
- `DELETE /{id}` — Delete MCP server config
**Authenticated endpoints** (all users, enabled servers only for
non-admins):
- `GET /` — List configs (admins see all, members see enabled-only with
admin fields redacted)
- `GET /{id}` — Get config by ID (with `auth_connected` populated
per-user)
**OAuth2 per-user auth flow:**
- `GET /{id}/oauth2/connect` — Initiate OAuth2 flow (state cookie CSRF
protection)
- `GET /{id}/oauth2/callback` — Handle OAuth2 callback, store tokens
- `DELETE /{id}/oauth2/disconnect` — Remove stored OAuth2 tokens
## Security
- **Secrets never returned**: `OAuth2ClientSecret`, `APIKeyValue`, and
`CustomHeaders` are never in API responses — only boolean indicators
(`has_oauth2_secret`, `has_api_key`, `has_custom_headers`).
- **Field redaction for non-admins**: `convertMCPServerConfigRedacted`
strips `OAuth2ClientID`, auth URLs, scopes, and `APIKeyHeader` from
non-admin responses.
- **dbcrypt encryption at rest**: All 5 secret fields use `dbcrypt_keys`
encryption with full encrypt-on-write / decrypt-on-read wrappers (11
dbcrypt method overrides + 2 helpers), following the same pattern as
`chat_providers.api_key`.
- **OAuth2 CSRF protection**: State parameter stored in `HttpOnly`
cookie with `HTTPCookies.Apply()` for correct `Secure`/`SameSite` behind
TLS-terminating proxies.
- **dbauthz authorization**: All 18 querier methods have authorization
wrappers. Read operations use `ActionRead`, write operations use
`ActionUpdate` on `ResourceDeploymentConfig`.
## Governance Model
| Control | Implementation |
|---------|---------------|
| **Global kill switch** | `enabled` defaults to `false` |
| **Availability policy** | `force_on` (always injected), `default_on`
(pre-selected), `default_off` (opt-in) |
| **Per-chat selection** | `mcp_server_ids` on `CreateChatRequest` /
`CreateChatMessageRequest` |
| **Auth gate** | OAuth2 servers require per-user auth before tools are
injected |
| **Tool-level allow/deny** | Arrays on `mcp_server_configs` for
granular tool filtering |
| **Secrets encrypted at rest** | Uses `dbcrypt_keys` (same pattern as
`chat_providers.api_key`) |
## Tests
8 test functions covering:
- Full CRUD lifecycle (create, list, update, delete)
- Non-admin visibility filtering (enabled-only, field redaction)
- `auth_connected` population for OAuth2 vs non-OAuth2 servers
- Availability policy validation (valid values + invalid rejection)
- Unique slug enforcement (409 Conflict)
- OAuth2 disconnect idempotency
- Chat creation with `mcp_server_ids` persistence
## Known Limitations (Deferred)
These are documented and intentional for an experimental feature:
- **Audit logging** not yet wired — will add when feature stabilizes
- **Cross-field validation** (e.g., OAuth2 fields required when
`auth_type=oauth2`) — admin-only endpoint, will add when stabilizing
- **`force_on` auto-injection** — query exists but not yet wired into
chatd tool injection (follow-up)
- **Additional test coverage** — 403 auth tests, GET-by-ID tests,
callback CSRF tests planned for follow-up
## What's NOT in this PR
- Frontend UI (admin panel + chat picker)
- Actual MCP client connections (`chatd/chatmcp/` manager)
- Tool injection into `chatloop/`
This commit is contained in:
@@ -379,6 +379,7 @@ type CreateOptions struct {
|
||||
ChatMode database.NullChatMode
|
||||
SystemPrompt string
|
||||
InitialUserContent []codersdk.ChatMessagePart
|
||||
MCPServerIDs []uuid.UUID
|
||||
}
|
||||
|
||||
// SendMessageBusyBehavior controls what happens when a chat is already active.
|
||||
@@ -401,6 +402,7 @@ type SendMessageOptions struct {
|
||||
Content []codersdk.ChatMessagePart
|
||||
ModelConfigID *uuid.UUID
|
||||
BusyBehavior SendMessageBusyBehavior
|
||||
MCPServerIDs *[]uuid.UUID
|
||||
}
|
||||
|
||||
// SendMessageResult contains the outcome of user message processing.
|
||||
@@ -450,6 +452,12 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
if len(opts.InitialUserContent) == 0 {
|
||||
return database.Chat{}, xerrors.New("initial user content is required")
|
||||
}
|
||||
// Ensure MCPServerIDs is non-nil so pq.Array produces '{}'
|
||||
// instead of SQL NULL, which violates the NOT NULL column
|
||||
// constraint.
|
||||
if opts.MCPServerIDs == nil {
|
||||
opts.MCPServerIDs = []uuid.UUID{}
|
||||
}
|
||||
|
||||
var chat database.Chat
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
@@ -465,6 +473,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
LastModelConfigID: opts.ModelConfigID,
|
||||
Title: opts.Title,
|
||||
Mode: opts.ChatMode,
|
||||
MCPServerIDs: opts.MCPServerIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert chat: %w", err)
|
||||
@@ -596,6 +605,17 @@ func (p *Server) SendMessage(
|
||||
modelConfigID = *opts.ModelConfigID
|
||||
}
|
||||
|
||||
// Update MCP server IDs on the chat when explicitly provided.
|
||||
if opts.MCPServerIDs != nil {
|
||||
lockedChat, err = tx.UpdateChatMCPServerIDs(ctx, database.UpdateChatMCPServerIDsParams{
|
||||
ID: opts.ChatID,
|
||||
MCPServerIDs: *opts.MCPServerIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update chat mcp server ids: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
existingQueued, err := tx.GetChatQueuedMessages(ctx, opts.ChatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get queued messages: %w", err)
|
||||
|
||||
@@ -284,6 +284,41 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate MCP server IDs exist.
|
||||
if len(req.MCPServerIDs) > 0 {
|
||||
//nolint:gocritic // Need to validate MCP server IDs exist.
|
||||
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), req.MCPServerIDs)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to validate MCP server IDs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(existingConfigs) != len(req.MCPServerIDs) {
|
||||
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
|
||||
for _, c := range existingConfigs {
|
||||
found[c.ID] = struct{}{}
|
||||
}
|
||||
var missing []string
|
||||
for _, id := range req.MCPServerIDs {
|
||||
if _, ok := found[id]; !ok {
|
||||
missing = append(missing, id.String())
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "One or more MCP server IDs are invalid.",
|
||||
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mcpServerIDs := req.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
@@ -291,6 +326,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ModelConfigID: modelConfigID,
|
||||
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -1456,6 +1492,36 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate MCP server IDs exist.
|
||||
if req.MCPServerIDs != nil && len(*req.MCPServerIDs) > 0 {
|
||||
//nolint:gocritic // Need to validate MCP server IDs exist.
|
||||
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), *req.MCPServerIDs)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to validate MCP server IDs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(existingConfigs) != len(*req.MCPServerIDs) {
|
||||
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
|
||||
for _, c := range existingConfigs {
|
||||
found[c.ID] = struct{}{}
|
||||
}
|
||||
var missing []string
|
||||
for _, id := range *req.MCPServerIDs {
|
||||
if _, ok := found[id]; !ok {
|
||||
missing = append(missing, id.String())
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "One or more MCP server IDs are invalid.",
|
||||
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sendResult, sendErr := api.chatDaemon.SendMessage(
|
||||
ctx,
|
||||
chatd.SendMessageOptions{
|
||||
@@ -1464,6 +1530,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
MCPServerIDs: req.MCPServerIDs,
|
||||
},
|
||||
)
|
||||
if sendErr != nil {
|
||||
@@ -2979,6 +3046,10 @@ func truncateRunes(value string, maxLen int) string {
|
||||
}
|
||||
|
||||
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
chat := codersdk.Chat{
|
||||
ID: c.ID,
|
||||
OwnerID: c.OwnerID,
|
||||
@@ -2988,6 +3059,7 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
|
||||
Archived: c.Archived,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
}
|
||||
if c.LastError.Valid {
|
||||
chat.LastError = &c.LastError.String
|
||||
|
||||
+19
-2
@@ -1233,10 +1233,27 @@ func New(options *Options) *API {
|
||||
r.Route("/mcp", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
|
||||
)
|
||||
// MCP server configuration endpoints.
|
||||
r.Route("/servers", func(r chi.Router) {
|
||||
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents))
|
||||
r.Get("/", api.listMCPServerConfigs)
|
||||
r.Post("/", api.createMCPServerConfig)
|
||||
r.Route("/{mcpServer}", func(r chi.Router) {
|
||||
r.Get("/", api.getMCPServerConfig)
|
||||
r.Patch("/", api.updateMCPServerConfig)
|
||||
r.Delete("/", api.deleteMCPServerConfig)
|
||||
// OAuth2 user flow
|
||||
r.Get("/oauth2/connect", api.mcpServerOAuth2Connect)
|
||||
r.Get("/oauth2/callback", api.mcpServerOAuth2Callback)
|
||||
r.Delete("/oauth2/disconnect", api.mcpServerOAuth2Disconnect)
|
||||
})
|
||||
})
|
||||
// MCP HTTP transport endpoint with mandatory authentication
|
||||
r.Mount("/http", api.mcpHTTPHandler())
|
||||
r.Route("/http", func(r chi.Router) {
|
||||
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP))
|
||||
r.Mount("/", api.mcpHTTPHandler())
|
||||
})
|
||||
})
|
||||
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
|
||||
r.Use(
|
||||
|
||||
@@ -20,6 +20,9 @@ const (
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs
|
||||
CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs
|
||||
CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
|
||||
@@ -1691,6 +1691,13 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
|
||||
return q.db.CleanTailnetTunnels(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -1920,6 +1927,20 @@ func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteMCPServerConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteMCPServerUserToken(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceOauth2App); err != nil {
|
||||
return err
|
||||
@@ -2763,6 +2784,13 @@ func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatP
|
||||
return q.db.GetEnabledChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledMCPServerConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
|
||||
}
|
||||
@@ -2821,6 +2849,13 @@ func (q *querier) GetFilteredInboxNotificationsByUserID(ctx context.Context, arg
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetFilteredInboxNotificationsByUserID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetForcedMCPServerConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID)
|
||||
}
|
||||
@@ -2961,6 +2996,48 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
|
||||
return q.db.GetLogoURL(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigBySlug(ctx, slug)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigsByIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
return q.db.GetMCPServerUserToken(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetMCPServerUserTokensByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
|
||||
return nil, err
|
||||
@@ -4727,6 +4804,13 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
|
||||
return q.db.InsertLicense(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.InsertMCPServerConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
|
||||
return database.WorkspaceAgentMemoryResourceMonitor{}, err
|
||||
@@ -5486,6 +5570,17 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatMCPServerIDs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
// Authorize update on the parent chat of the edited message.
|
||||
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
|
||||
@@ -5650,6 +5745,13 @@ func (q *querier) UpdateInboxNotificationReadStatus(ctx context.Context, args da
|
||||
return update(q.log, q.auth, fetchFunc, q.db.UpdateInboxNotificationReadStatus)(ctx, args)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.UpdateMCPServerConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
|
||||
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
@@ -6695,6 +6797,13 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
|
||||
return q.db.UpsertLogoURL(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
return q.db.UpsertMCPServerUserToken(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1005,6 +1005,114 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("CleanupDeletedMCPServerIDsFromChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().CleanupDeletedMCPServerIDsFromChats(gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceChat, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteMCPServerConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.DeleteMCPServerUserTokenParams{
|
||||
MCPServerConfigID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().DeleteMCPServerUserToken(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetEnabledMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetEnabledMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetForcedMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetForcedMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetMCPServerConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
|
||||
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetMCPServerConfigBySlug", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
slug := "test-mcp-server"
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{Slug: slug})
|
||||
dbm.EXPECT().GetMCPServerConfigBySlug(gomock.Any(), slug).Return(config, nil).AnyTimes()
|
||||
check.Args(slug).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetMCPServerConfigsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
ids := []uuid.UUID{configA.ID, configB.ID}
|
||||
dbm.EXPECT().GetMCPServerConfigsByIDs(gomock.Any(), ids).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetMCPServerUserTokenParams{
|
||||
MCPServerConfigID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
|
||||
dbm.EXPECT().GetMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(token)
|
||||
}))
|
||||
s.Run("GetMCPServerUserTokensByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
tokens := []database.MCPServerUserToken{testutil.Fake(s.T(), faker, database.MCPServerUserToken{UserID: userID})}
|
||||
dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens)
|
||||
}))
|
||||
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertMCPServerConfigParams{
|
||||
DisplayName: "Test MCP Server",
|
||||
Slug: "test-mcp-server",
|
||||
}
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{DisplayName: arg.DisplayName, Slug: arg.Slug})
|
||||
dbm.EXPECT().InsertMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpdateChatMCPServerIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatMCPServerIDsParams{
|
||||
ID: chat.ID,
|
||||
MCPServerIDs: []uuid.UUID{uuid.New()},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
arg := database.UpdateMCPServerConfigParams{
|
||||
ID: config.ID,
|
||||
DisplayName: "Updated MCP Server",
|
||||
Slug: "updated-mcp-server",
|
||||
}
|
||||
dbm.EXPECT().UpdateMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpsertMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
AccessToken: "test-access-token",
|
||||
TokenType: "bearer",
|
||||
}
|
||||
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
|
||||
dbm.EXPECT().UpsertMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(token)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
@@ -1381,8 +1489,8 @@ func (s *MethodTestSuite) TestLicense() {
|
||||
check.Args().Asserts().Returns("value")
|
||||
}))
|
||||
s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes()
|
||||
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"})
|
||||
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}, nil).AnyTimes()
|
||||
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"})
|
||||
}))
|
||||
s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes()
|
||||
|
||||
@@ -264,6 +264,14 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
m.queryLatencies.WithLabelValues("CleanupDeletedMCPServerIDsFromChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CleanupDeletedMCPServerIDsFromChats").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -480,6 +488,22 @@ func (m queryMetricsStore) DeleteLicense(ctx context.Context, id int32) (int32,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteMCPServerConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteMCPServerConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerConfigByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteMCPServerUserToken(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteMCPServerUserToken").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserToken").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteOAuth2ProviderAppByClientID(ctx, id)
|
||||
@@ -1328,6 +1352,14 @@ func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]datab
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledMCPServerConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledMCPServerConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
|
||||
@@ -1384,6 +1416,14 @@ func (m queryMetricsStore) GetFilteredInboxNotificationsByUserID(ctx context.Con
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetForcedMCPServerConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetForcedMCPServerConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetForcedMCPServerConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetGitSSHKey(ctx, userID)
|
||||
@@ -1544,6 +1584,54 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) {
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigBySlug(ctx, slug)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigBySlug").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigBySlug").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigsByIDs(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigsByIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigsByIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerUserToken(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerUserToken").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserToken").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerUserTokensByUserID(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerUserTokensByUserID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserTokensByUserID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg)
|
||||
@@ -3184,6 +3272,14 @@ func (m queryMetricsStore) InsertLicense(ctx context.Context, arg database.Inser
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertMCPServerConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertMCPServerConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertMCPServerConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertMemoryResourceMonitor(ctx, arg)
|
||||
@@ -3856,6 +3952,14 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatMCPServerIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMCPServerIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMessageByID(ctx, arg)
|
||||
@@ -3960,6 +4064,14 @@ func (m queryMetricsStore) UpdateInboxNotificationReadStatus(ctx context.Context
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateMCPServerConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateMCPServerConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateMCPServerConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateMemberRoles(ctx, arg)
|
||||
@@ -4696,6 +4808,14 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertMCPServerUserToken").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserToken").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertNotificationReportGeneratorLog(ctx, arg)
|
||||
|
||||
@@ -334,6 +334,20 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx)
|
||||
}
|
||||
|
||||
// CleanupDeletedMCPServerIDsFromChats mocks base method.
|
||||
func (m *MockStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupDeletedMCPServerIDsFromChats", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupDeletedMCPServerIDsFromChats indicates an expected call of CleanupDeletedMCPServerIDsFromChats.
|
||||
func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -769,6 +783,34 @@ func (mr *MockStoreMockRecorder) DeleteLicense(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLicense", reflect.TypeOf((*MockStore)(nil).DeleteLicense), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteMCPServerConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteMCPServerConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteMCPServerConfigByID indicates an expected call of DeleteMCPServerConfigByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteMCPServerUserToken mocks base method.
|
||||
func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteMCPServerUserToken", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteMCPServerUserToken indicates an expected call of DeleteMCPServerUserToken.
|
||||
func (mr *MockStoreMockRecorder) DeleteMCPServerUserToken(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserToken), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOAuth2ProviderAppByClientID mocks base method.
|
||||
func (m *MockStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2437,6 +2479,21 @@ func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetEnabledMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledMCPServerConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledMCPServerConfigs indicates an expected call of GetEnabledMCPServerConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledMCPServerConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledMCPServerConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetExternalAuthLink mocks base method.
|
||||
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2542,6 +2599,21 @@ func (mr *MockStoreMockRecorder) GetFilteredInboxNotificationsByUserID(ctx, arg
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).GetFilteredInboxNotificationsByUserID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetForcedMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetForcedMCPServerConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetForcedMCPServerConfigs indicates an expected call of GetForcedMCPServerConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetForcedMCPServerConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForcedMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetForcedMCPServerConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetGitSSHKey mocks base method.
|
||||
func (m *MockStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2842,6 +2914,96 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigByID mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigByID indicates an expected call of GetMCPServerConfigByID.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigBySlug mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigBySlug", ctx, slug)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigBySlug indicates an expected call of GetMCPServerConfigBySlug.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigBySlug(ctx, slug any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigBySlug", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigBySlug), ctx, slug)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigs indicates an expected call of GetMCPServerConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigsByIDs mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigsByIDs", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigsByIDs indicates an expected call of GetMCPServerConfigsByIDs.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetMCPServerUserToken mocks base method.
|
||||
func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerUserToken", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerUserToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerUserToken indicates an expected call of GetMCPServerUserToken.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerUserToken(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserToken), ctx, arg)
|
||||
}
|
||||
|
||||
// GetMCPServerUserTokensByUserID mocks base method.
|
||||
func (m *MockStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerUserTokensByUserID", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.MCPServerUserToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerUserTokensByUserID indicates an expected call of GetMCPServerUserTokensByUserID.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerUserTokensByUserID(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserTokensByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserTokensByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// GetNotificationMessagesByStatus mocks base method.
|
||||
func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5957,6 +6119,21 @@ func (mr *MockStoreMockRecorder) InsertLicense(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertLicense", reflect.TypeOf((*MockStore)(nil).InsertLicense), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertMCPServerConfig mocks base method.
|
||||
func (m *MockStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertMCPServerConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertMCPServerConfig indicates an expected call of InsertMCPServerConfig.
|
||||
func (mr *MockStoreMockRecorder) InsertMCPServerConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMCPServerConfig", reflect.TypeOf((*MockStore)(nil).InsertMCPServerConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertMemoryResourceMonitor mocks base method.
|
||||
func (m *MockStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7256,6 +7433,21 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs mocks base method.
|
||||
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatMCPServerIDs", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs indicates an expected call of UpdateChatMCPServerIDs.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatMCPServerIDs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMCPServerIDs", reflect.TypeOf((*MockStore)(nil).UpdateChatMCPServerIDs), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMessageByID mocks base method.
|
||||
func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7449,6 +7641,21 @@ func (mr *MockStoreMockRecorder) UpdateInboxNotificationReadStatus(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInboxNotificationReadStatus", reflect.TypeOf((*MockStore)(nil).UpdateInboxNotificationReadStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateMCPServerConfig mocks base method.
|
||||
func (m *MockStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateMCPServerConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateMCPServerConfig indicates an expected call of UpdateMCPServerConfig.
|
||||
func (mr *MockStoreMockRecorder) UpdateMCPServerConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMCPServerConfig", reflect.TypeOf((*MockStore)(nil).UpdateMCPServerConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateMemberRoles mocks base method.
|
||||
func (m *MockStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8773,6 +8980,21 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertMCPServerUserToken mocks base method.
|
||||
func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertMCPServerUserToken", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerUserToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertMCPServerUserToken indicates an expected call of UpsertMCPServerUserToken.
|
||||
func (mr *MockStoreMockRecorder) UpsertMCPServerUserToken(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserToken), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertNotificationReportGeneratorLog mocks base method.
|
||||
func (m *MockStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+94
-1
@@ -1393,7 +1393,8 @@ CREATE TABLE chats (
|
||||
last_model_config_id uuid NOT NULL,
|
||||
archived boolean DEFAULT false NOT NULL,
|
||||
last_error text,
|
||||
mode chat_mode
|
||||
mode chat_mode,
|
||||
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -1670,6 +1671,53 @@ CREATE SEQUENCE licenses_id_seq
|
||||
|
||||
ALTER SEQUENCE licenses_id_seq OWNED BY licenses.id;
|
||||
|
||||
CREATE TABLE mcp_server_configs (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
display_name text NOT NULL,
|
||||
slug text NOT NULL,
|
||||
description text DEFAULT ''::text NOT NULL,
|
||||
icon_url text DEFAULT ''::text NOT NULL,
|
||||
transport text DEFAULT 'streamable_http'::text NOT NULL,
|
||||
url text NOT NULL,
|
||||
auth_type text DEFAULT 'none'::text NOT NULL,
|
||||
oauth2_client_id text DEFAULT ''::text NOT NULL,
|
||||
oauth2_client_secret text DEFAULT ''::text NOT NULL,
|
||||
oauth2_client_secret_key_id text,
|
||||
oauth2_auth_url text DEFAULT ''::text NOT NULL,
|
||||
oauth2_token_url text DEFAULT ''::text NOT NULL,
|
||||
oauth2_scopes text DEFAULT ''::text NOT NULL,
|
||||
api_key_header text DEFAULT 'Authorization'::text NOT NULL,
|
||||
api_key_value text DEFAULT ''::text NOT NULL,
|
||||
api_key_value_key_id text,
|
||||
custom_headers text DEFAULT '{}'::text NOT NULL,
|
||||
custom_headers_key_id text,
|
||||
tool_allow_list text[] DEFAULT '{}'::text[] NOT NULL,
|
||||
tool_deny_list text[] DEFAULT '{}'::text[] NOT NULL,
|
||||
availability text DEFAULT 'default_off'::text NOT NULL,
|
||||
enabled boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
updated_by uuid,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))),
|
||||
CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))),
|
||||
CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text])))
|
||||
);
|
||||
|
||||
CREATE TABLE mcp_server_user_tokens (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
mcp_server_config_id uuid NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
access_token text NOT NULL,
|
||||
access_token_key_id text,
|
||||
refresh_token text DEFAULT ''::text NOT NULL,
|
||||
refresh_token_key_id text,
|
||||
token_type text DEFAULT 'Bearer'::text NOT NULL,
|
||||
expiry timestamp with time zone,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE notification_messages (
|
||||
id uuid NOT NULL,
|
||||
notification_template_id uuid NOT NULL,
|
||||
@@ -3343,6 +3391,18 @@ ALTER TABLE ONLY licenses
|
||||
ALTER TABLE ONLY licenses
|
||||
ADD CONSTRAINT licenses_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY notification_messages
|
||||
ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3691,6 +3751,12 @@ CREATE INDEX idx_inbox_notifications_user_id_read_at ON inbox_notifications USIN
|
||||
|
||||
CREATE INDEX idx_inbox_notifications_user_id_template_id_targets ON inbox_notifications USING btree (user_id, template_id, targets);
|
||||
|
||||
CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs USING btree (enabled) WHERE (enabled = true);
|
||||
|
||||
CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs USING btree (enabled, availability) WHERE ((enabled = true) AND (availability = 'force_on'::text));
|
||||
|
||||
CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens USING btree (user_id);
|
||||
|
||||
CREATE INDEX idx_notification_messages_status ON notification_messages USING btree (status);
|
||||
|
||||
CREATE INDEX idx_organization_member_organization_id_uuid ON organization_members USING btree (organization_id);
|
||||
@@ -4015,6 +4081,33 @@ ALTER TABLE ONLY jfrog_xray_scans
|
||||
ALTER TABLE ONLY jfrog_xray_scans
|
||||
ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY notification_messages
|
||||
ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -40,6 +40,15 @@ const (
|
||||
ForeignKeyInboxNotificationsUserID ForeignKeyConstraint = "inbox_notifications_user_id_fkey" // ALTER TABLE ONLY inbox_notifications ADD CONSTRAINT inbox_notifications_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyJfrogXrayScansAgentID ForeignKeyConstraint = "jfrog_xray_scans_agent_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
|
||||
ForeignKeyJfrogXrayScansWorkspaceID ForeignKeyConstraint = "jfrog_xray_scans_workspace_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
ForeignKeyMcpServerConfigsAPIKeyValueKeyID ForeignKeyConstraint = "mcp_server_configs_api_key_value_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerConfigsCreatedBy ForeignKeyConstraint = "mcp_server_configs_created_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
ForeignKeyMcpServerConfigsCustomHeadersKeyID ForeignKeyConstraint = "mcp_server_configs_custom_headers_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerConfigsOauth2ClientSecretKeyID ForeignKeyConstraint = "mcp_server_configs_oauth2_client_secret_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerConfigsUpdatedBy ForeignKeyConstraint = "mcp_server_configs_updated_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
ForeignKeyMcpServerUserTokensAccessTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_access_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerUserTokensMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_tokens_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
|
||||
ForeignKeyMcpServerUserTokensRefreshTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_refresh_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerUserTokensUserID ForeignKeyConstraint = "mcp_server_user_tokens_user_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyNotificationMessagesNotificationTemplateID ForeignKeyConstraint = "notification_messages_notification_template_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
|
||||
ForeignKeyNotificationMessagesUserID ForeignKeyConstraint = "notification_messages_user_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyNotificationPreferencesNotificationTemplateID ForeignKeyConstraint = "notification_preferences_notification_template_id_fkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS mcp_server_ids;
|
||||
DROP INDEX IF EXISTS idx_mcp_server_configs_enabled;
|
||||
DROP INDEX IF EXISTS idx_mcp_server_configs_forced;
|
||||
DROP INDEX IF EXISTS idx_mcp_server_user_tokens_user_id;
|
||||
DROP TABLE IF EXISTS mcp_server_user_tokens;
|
||||
DROP TABLE IF EXISTS mcp_server_configs;
|
||||
@@ -0,0 +1,75 @@
|
||||
CREATE TABLE mcp_server_configs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
|
||||
-- Display
|
||||
display_name TEXT NOT NULL,
|
||||
slug TEXT NOT NULL UNIQUE,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
icon_url TEXT NOT NULL DEFAULT '',
|
||||
|
||||
-- Connection
|
||||
transport TEXT NOT NULL DEFAULT 'streamable_http'
|
||||
CHECK (transport IN ('streamable_http', 'sse')),
|
||||
url TEXT NOT NULL,
|
||||
|
||||
-- Authentication
|
||||
auth_type TEXT NOT NULL DEFAULT 'none'
|
||||
CHECK (auth_type IN ('none', 'oauth2', 'api_key', 'custom_headers')),
|
||||
|
||||
-- OAuth2 config (when auth_type = 'oauth2')
|
||||
oauth2_client_id TEXT NOT NULL DEFAULT '',
|
||||
oauth2_client_secret TEXT NOT NULL DEFAULT '',
|
||||
oauth2_client_secret_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
oauth2_auth_url TEXT NOT NULL DEFAULT '',
|
||||
oauth2_token_url TEXT NOT NULL DEFAULT '',
|
||||
oauth2_scopes TEXT NOT NULL DEFAULT '',
|
||||
|
||||
-- API key config (when auth_type = 'api_key')
|
||||
api_key_header TEXT NOT NULL DEFAULT 'Authorization',
|
||||
api_key_value TEXT NOT NULL DEFAULT '',
|
||||
api_key_value_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
|
||||
-- Custom headers (when auth_type = 'custom_headers')
|
||||
custom_headers TEXT NOT NULL DEFAULT '{}',
|
||||
custom_headers_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
|
||||
-- Tool governance
|
||||
tool_allow_list TEXT[] NOT NULL DEFAULT '{}',
|
||||
tool_deny_list TEXT[] NOT NULL DEFAULT '{}',
|
||||
|
||||
-- Availability policy
|
||||
availability TEXT NOT NULL DEFAULT 'default_off'
|
||||
CHECK (availability IN ('force_on', 'default_on', 'default_off')),
|
||||
|
||||
-- Lifecycle
|
||||
enabled BOOLEAN NOT NULL DEFAULT false,
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
updated_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE mcp_server_user_tokens (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
mcp_server_config_id UUID NOT NULL REFERENCES mcp_server_configs(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
|
||||
access_token TEXT NOT NULL,
|
||||
access_token_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
refresh_token TEXT NOT NULL DEFAULT '',
|
||||
refresh_token_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
token_type TEXT NOT NULL DEFAULT 'Bearer',
|
||||
expiry TIMESTAMPTZ,
|
||||
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
|
||||
UNIQUE (mcp_server_config_id, user_id)
|
||||
);
|
||||
|
||||
-- Add MCP server selection to chats (per-chat, like model_config_id)
|
||||
ALTER TABLE chats ADD COLUMN mcp_server_ids UUID[] NOT NULL DEFAULT '{}';
|
||||
|
||||
CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs(enabled) WHERE enabled = TRUE;
|
||||
CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs(enabled, availability) WHERE enabled = TRUE AND availability = 'force_on';
|
||||
CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens(user_id);
|
||||
+48
@@ -0,0 +1,48 @@
|
||||
INSERT INTO mcp_server_configs (
|
||||
id,
|
||||
display_name,
|
||||
slug,
|
||||
url,
|
||||
transport,
|
||||
auth_type,
|
||||
availability,
|
||||
enabled,
|
||||
created_by,
|
||||
updated_by,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
|
||||
'Fixture MCP Server',
|
||||
'fixture-mcp-server',
|
||||
'https://mcp.example.com/sse',
|
||||
'sse',
|
||||
'none',
|
||||
'default_on',
|
||||
TRUE,
|
||||
'30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com
|
||||
'30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
);
|
||||
|
||||
INSERT INTO mcp_server_user_tokens (
|
||||
id,
|
||||
mcp_server_config_id,
|
||||
user_id,
|
||||
access_token,
|
||||
token_type,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
'b2c3d4e5-f6a7-8901-bcde-f12345678901',
|
||||
'a1b2c3d4-e5f6-7890-abcd-ef1234567890',
|
||||
id,
|
||||
'fixture-access-token',
|
||||
'Bearer',
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:00+00'
|
||||
FROM users
|
||||
ORDER BY created_at, id
|
||||
LIMIT 1;
|
||||
@@ -787,6 +787,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4167,6 +4167,7 @@ type Chat struct {
|
||||
Archived bool `db:"archived" json:"archived"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
@@ -4452,6 +4453,50 @@ type License struct {
|
||||
UUID uuid.UUID `db:"uuid" json:"uuid"`
|
||||
}
|
||||
|
||||
type MCPServerConfig struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Slug string `db:"slug" json:"slug"`
|
||||
Description string `db:"description" json:"description"`
|
||||
IconURL string `db:"icon_url" json:"icon_url"`
|
||||
Transport string `db:"transport" json:"transport"`
|
||||
Url string `db:"url" json:"url"`
|
||||
AuthType string `db:"auth_type" json:"auth_type"`
|
||||
OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"`
|
||||
OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"`
|
||||
OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"`
|
||||
OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"`
|
||||
OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"`
|
||||
OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"`
|
||||
APIKeyHeader string `db:"api_key_header" json:"api_key_header"`
|
||||
APIKeyValue string `db:"api_key_value" json:"api_key_value"`
|
||||
APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"`
|
||||
CustomHeaders string `db:"custom_headers" json:"custom_headers"`
|
||||
CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"`
|
||||
ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"`
|
||||
ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"`
|
||||
Availability string `db:"availability" json:"availability"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type MCPServerUserToken struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
AccessToken string `db:"access_token" json:"access_token"`
|
||||
AccessTokenKeyID sql.NullString `db:"access_token_key_id" json:"access_token_key_id"`
|
||||
RefreshToken string `db:"refresh_token" json:"refresh_token"`
|
||||
RefreshTokenKeyID sql.NullString `db:"refresh_token_key_id" json:"refresh_token_key_id"`
|
||||
TokenType string `db:"token_type" json:"token_type"`
|
||||
Expiry sql.NullTime `db:"expiry" json:"expiry"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type NotificationMessage struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
NotificationTemplateID uuid.UUID `db:"notification_template_id" json:"notification_template_id"`
|
||||
|
||||
@@ -74,6 +74,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetCoordinators(ctx context.Context) error
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
|
||||
@@ -110,6 +111,8 @@ type sqlcQuerier interface {
|
||||
DeleteGroupByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error
|
||||
DeleteLicense(ctx context.Context, id int32) (int32, error)
|
||||
DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error
|
||||
DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -269,6 +272,7 @@ type sqlcQuerier interface {
|
||||
GetEligibleProvisionerDaemonsByProvisionerJobIDs(ctx context.Context, provisionerJobIds []uuid.UUID) ([]GetEligibleProvisionerDaemonsByProvisionerJobIDsRow, error)
|
||||
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
|
||||
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
|
||||
GetFailedWorkspaceBuildsByTemplateID(ctx context.Context, arg GetFailedWorkspaceBuildsByTemplateIDParams) ([]GetFailedWorkspaceBuildsByTemplateIDRow, error)
|
||||
@@ -284,6 +288,7 @@ type sqlcQuerier interface {
|
||||
// param created_at_opt: The created_at timestamp to filter by. This parameter is usd for pagination - it fetches notifications created before the specified timestamp if it is not the zero value
|
||||
// param limit_opt: The limit of notifications to fetch. If the limit is not specified, it defaults to 25
|
||||
GetFilteredInboxNotificationsByUserID(ctx context.Context, arg GetFilteredInboxNotificationsByUserIDParams) ([]InboxNotification, error)
|
||||
GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
|
||||
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
|
||||
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
|
||||
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
|
||||
@@ -312,6 +317,12 @@ type sqlcQuerier interface {
|
||||
GetLicenseByID(ctx context.Context, id int32) (License, error)
|
||||
GetLicenses(ctx context.Context) ([]License, error)
|
||||
GetLogoURL(ctx context.Context) (string, error)
|
||||
GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error)
|
||||
GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error)
|
||||
GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
|
||||
GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error)
|
||||
GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error)
|
||||
GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error)
|
||||
GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error)
|
||||
// Fetch the notification report generator log indicating recent activity.
|
||||
GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error)
|
||||
@@ -675,6 +686,7 @@ type sqlcQuerier interface {
|
||||
InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error
|
||||
InsertInboxNotification(ctx context.Context, arg InsertInboxNotificationParams) (InboxNotification, error)
|
||||
InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error)
|
||||
InsertMCPServerConfig(ctx context.Context, arg InsertMCPServerConfigParams) (MCPServerConfig, error)
|
||||
InsertMemoryResourceMonitor(ctx context.Context, arg InsertMemoryResourceMonitorParams) (WorkspaceAgentMemoryResourceMonitor, error)
|
||||
// Inserts any group by name that does not exist. All new groups are given
|
||||
// a random uuid, are inserted into the same organization. They have the default
|
||||
@@ -796,6 +808,7 @@ type sqlcQuerier interface {
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error)
|
||||
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
@@ -813,6 +826,7 @@ type sqlcQuerier interface {
|
||||
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
|
||||
UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error)
|
||||
UpdateInboxNotificationReadStatus(ctx context.Context, arg UpdateInboxNotificationReadStatusParams) error
|
||||
UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPServerConfigParams) (MCPServerConfig, error)
|
||||
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)
|
||||
UpdateMemoryResourceMonitor(ctx context.Context, arg UpdateMemoryResourceMonitorParams) error
|
||||
UpdateNotificationTemplateMethodByID(ctx context.Context, arg UpdateNotificationTemplateMethodByIDParams) (NotificationTemplate, error)
|
||||
@@ -917,6 +931,7 @@ type sqlcQuerier interface {
|
||||
UpsertHealthSettings(ctx context.Context, value string) error
|
||||
UpsertLastUpdateCheck(ctx context.Context, value string) error
|
||||
UpsertLogoURL(ctx context.Context, value string) error
|
||||
UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error)
|
||||
// Insert or update notification report generator logs with recent activity.
|
||||
UpsertNotificationReportGeneratorLog(ctx context.Context, arg UpsertNotificationReportGeneratorLogParams) error
|
||||
UpsertNotificationsSettings(ctx context.Context, value string) error
|
||||
|
||||
@@ -1653,12 +1653,12 @@ func TestDefaultProxy(t *testing.T) {
|
||||
require.NoError(t, err, "get def proxy")
|
||||
|
||||
require.Equal(t, defProxy.DisplayName, "Default")
|
||||
require.Equal(t, defProxy.IconUrl, "/emojis/1f3e1.png")
|
||||
require.Equal(t, defProxy.IconURL, "/emojis/1f3e1.png")
|
||||
|
||||
// Set the proxy values
|
||||
args := database.UpsertDefaultProxyParams{
|
||||
DisplayName: "displayname",
|
||||
IconUrl: "/icon.png",
|
||||
IconURL: "/icon.png",
|
||||
}
|
||||
err = db.UpsertDefaultProxy(ctx, args)
|
||||
require.NoError(t, err, "insert def proxy")
|
||||
@@ -1666,12 +1666,12 @@ func TestDefaultProxy(t *testing.T) {
|
||||
defProxy, err = db.GetDefaultProxyConfig(ctx)
|
||||
require.NoError(t, err, "get def proxy")
|
||||
require.Equal(t, defProxy.DisplayName, args.DisplayName)
|
||||
require.Equal(t, defProxy.IconUrl, args.IconUrl)
|
||||
require.Equal(t, defProxy.IconURL, args.IconURL)
|
||||
|
||||
// Upsert values
|
||||
args = database.UpsertDefaultProxyParams{
|
||||
DisplayName: "newdisplayname",
|
||||
IconUrl: "/newicon.png",
|
||||
IconURL: "/newicon.png",
|
||||
}
|
||||
err = db.UpsertDefaultProxy(ctx, args)
|
||||
require.NoError(t, err, "upsert def proxy")
|
||||
@@ -1679,7 +1679,7 @@ func TestDefaultProxy(t *testing.T) {
|
||||
defProxy, err = db.GetDefaultProxyConfig(ctx)
|
||||
require.NoError(t, err, "get def proxy")
|
||||
require.Equal(t, defProxy.DisplayName, args.DisplayName)
|
||||
require.Equal(t, defProxy.IconUrl, args.IconUrl)
|
||||
require.Equal(t, defProxy.IconURL, args.IconURL)
|
||||
|
||||
// Ensure other site configs are the same
|
||||
found, err := db.GetDeploymentID(ctx)
|
||||
|
||||
+864
-15
File diff suppressed because it is too large
Load Diff
@@ -180,7 +180,8 @@ INSERT INTO chats (
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
title,
|
||||
mode
|
||||
mode,
|
||||
mcp_server_ids
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
@@ -188,7 +189,8 @@ INSERT INTO chats (
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
@last_model_config_id::uuid,
|
||||
@title::text,
|
||||
sqlc.narg('mode')::chat_mode
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[])
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -295,6 +297,17 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatMCPServerIDs :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
mcp_server_ids = @mcp_server_ids::uuid[],
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: AcquireChats :many
|
||||
-- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED
|
||||
-- to prevent multiple replicas from acquiring the same chat.
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
-- name: GetMCPServerConfigByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_configs
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetMCPServerConfigBySlug :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_configs
|
||||
WHERE
|
||||
slug = @slug::text;
|
||||
|
||||
-- name: GetMCPServerConfigs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_configs
|
||||
ORDER BY
|
||||
display_name ASC;
|
||||
|
||||
-- name: GetEnabledMCPServerConfigs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_configs
|
||||
WHERE
|
||||
enabled = TRUE
|
||||
ORDER BY
|
||||
display_name ASC;
|
||||
|
||||
-- name: GetMCPServerConfigsByIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_configs
|
||||
WHERE
|
||||
id = ANY(@ids::uuid[])
|
||||
ORDER BY
|
||||
display_name ASC;
|
||||
|
||||
-- name: GetForcedMCPServerConfigs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_configs
|
||||
WHERE
|
||||
enabled = TRUE
|
||||
AND availability = 'force_on'
|
||||
ORDER BY
|
||||
display_name ASC;
|
||||
|
||||
-- name: InsertMCPServerConfig :one
|
||||
INSERT INTO mcp_server_configs (
|
||||
display_name,
|
||||
slug,
|
||||
description,
|
||||
icon_url,
|
||||
transport,
|
||||
url,
|
||||
auth_type,
|
||||
oauth2_client_id,
|
||||
oauth2_client_secret,
|
||||
oauth2_client_secret_key_id,
|
||||
oauth2_auth_url,
|
||||
oauth2_token_url,
|
||||
oauth2_scopes,
|
||||
api_key_header,
|
||||
api_key_value,
|
||||
api_key_value_key_id,
|
||||
custom_headers,
|
||||
custom_headers_key_id,
|
||||
tool_allow_list,
|
||||
tool_deny_list,
|
||||
availability,
|
||||
enabled,
|
||||
created_by,
|
||||
updated_by
|
||||
) VALUES (
|
||||
@display_name::text,
|
||||
@slug::text,
|
||||
@description::text,
|
||||
@icon_url::text,
|
||||
@transport::text,
|
||||
@url::text,
|
||||
@auth_type::text,
|
||||
@oauth2_client_id::text,
|
||||
@oauth2_client_secret::text,
|
||||
sqlc.narg('oauth2_client_secret_key_id')::text,
|
||||
@oauth2_auth_url::text,
|
||||
@oauth2_token_url::text,
|
||||
@oauth2_scopes::text,
|
||||
@api_key_header::text,
|
||||
@api_key_value::text,
|
||||
sqlc.narg('api_key_value_key_id')::text,
|
||||
@custom_headers::text,
|
||||
sqlc.narg('custom_headers_key_id')::text,
|
||||
@tool_allow_list::text[],
|
||||
@tool_deny_list::text[],
|
||||
@availability::text,
|
||||
@enabled::boolean,
|
||||
@created_by::uuid,
|
||||
@updated_by::uuid
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateMCPServerConfig :one
|
||||
UPDATE
|
||||
mcp_server_configs
|
||||
SET
|
||||
display_name = @display_name::text,
|
||||
slug = @slug::text,
|
||||
description = @description::text,
|
||||
icon_url = @icon_url::text,
|
||||
transport = @transport::text,
|
||||
url = @url::text,
|
||||
auth_type = @auth_type::text,
|
||||
oauth2_client_id = @oauth2_client_id::text,
|
||||
oauth2_client_secret = @oauth2_client_secret::text,
|
||||
oauth2_client_secret_key_id = sqlc.narg('oauth2_client_secret_key_id')::text,
|
||||
oauth2_auth_url = @oauth2_auth_url::text,
|
||||
oauth2_token_url = @oauth2_token_url::text,
|
||||
oauth2_scopes = @oauth2_scopes::text,
|
||||
api_key_header = @api_key_header::text,
|
||||
api_key_value = @api_key_value::text,
|
||||
api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text,
|
||||
custom_headers = @custom_headers::text,
|
||||
custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text,
|
||||
tool_allow_list = @tool_allow_list::text[],
|
||||
tool_deny_list = @tool_deny_list::text[],
|
||||
availability = @availability::text,
|
||||
enabled = @enabled::boolean,
|
||||
updated_by = @updated_by::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: DeleteMCPServerConfigByID :exec
|
||||
DELETE FROM
|
||||
mcp_server_configs
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetMCPServerUserToken :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_user_tokens
|
||||
WHERE
|
||||
mcp_server_config_id = @mcp_server_config_id::uuid
|
||||
AND user_id = @user_id::uuid;
|
||||
|
||||
-- name: GetMCPServerUserTokensByUserID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
mcp_server_user_tokens
|
||||
WHERE
|
||||
user_id = @user_id::uuid;
|
||||
|
||||
-- name: UpsertMCPServerUserToken :one
|
||||
INSERT INTO mcp_server_user_tokens (
|
||||
mcp_server_config_id,
|
||||
user_id,
|
||||
access_token,
|
||||
access_token_key_id,
|
||||
refresh_token,
|
||||
refresh_token_key_id,
|
||||
token_type,
|
||||
expiry
|
||||
) VALUES (
|
||||
@mcp_server_config_id::uuid,
|
||||
@user_id::uuid,
|
||||
@access_token::text,
|
||||
sqlc.narg('access_token_key_id')::text,
|
||||
@refresh_token::text,
|
||||
sqlc.narg('refresh_token_key_id')::text,
|
||||
@token_type::text,
|
||||
sqlc.narg('expiry')::timestamptz
|
||||
)
|
||||
ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET
|
||||
access_token = @access_token::text,
|
||||
access_token_key_id = sqlc.narg('access_token_key_id')::text,
|
||||
refresh_token = @refresh_token::text,
|
||||
refresh_token_key_id = sqlc.narg('refresh_token_key_id')::text,
|
||||
token_type = @token_type::text,
|
||||
expiry = sqlc.narg('expiry')::timestamptz,
|
||||
updated_at = NOW()
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: DeleteMCPServerUserToken :exec
|
||||
DELETE FROM
|
||||
mcp_server_user_tokens
|
||||
WHERE
|
||||
mcp_server_config_id = @mcp_server_config_id::uuid
|
||||
AND user_id = @user_id::uuid;
|
||||
|
||||
-- name: CleanupDeletedMCPServerIDsFromChats :exec
|
||||
UPDATE chats
|
||||
SET mcp_server_ids = (
|
||||
SELECT COALESCE(array_agg(sid), '{}')
|
||||
FROM unnest(chats.mcp_server_ids) AS sid
|
||||
WHERE sid IN (SELECT id FROM mcp_server_configs)
|
||||
)
|
||||
WHERE mcp_server_ids != '{}'
|
||||
AND NOT (mcp_server_ids <@ COALESCE((SELECT array_agg(id) FROM mcp_server_configs), '{}'));
|
||||
@@ -236,6 +236,28 @@ sql:
|
||||
aibridge_token_usage: AIBridgeTokenUsage
|
||||
aibridge_user_prompt: AIBridgeUserPrompt
|
||||
aibridge_model_thought: AIBridgeModelThought
|
||||
mcp_server_config: MCPServerConfig
|
||||
mcp_server_configs: MCPServerConfigs
|
||||
mcp_server_user_token: MCPServerUserToken
|
||||
mcp_server_user_tokens: MCPServerUserTokens
|
||||
mcp_server_tool_snapshot: MCPServerToolSnapshot
|
||||
mcp_server_tool_snapshots: MCPServerToolSnapshots
|
||||
mcp_server_config_id: MCPServerConfigID
|
||||
mcp_server_ids: MCPServerIDs
|
||||
icon_url: IconURL
|
||||
oauth2_client_id: OAuth2ClientID
|
||||
oauth2_client_secret: OAuth2ClientSecret
|
||||
oauth2_client_secret_key_id: OAuth2ClientSecretKeyID
|
||||
oauth2_auth_url: OAuth2AuthURL
|
||||
oauth2_token_url: OAuth2TokenURL
|
||||
oauth2_scopes: OAuth2Scopes
|
||||
api_key_header: APIKeyHeader
|
||||
api_key_value: APIKeyValue
|
||||
api_key_value_key_id: APIKeyValueKeyID
|
||||
custom_headers_key_id: CustomHeadersKeyID
|
||||
tools_json: ToolsJSON
|
||||
access_token_key_id: AccessTokenKeyID
|
||||
refresh_token_key_id: RefreshTokenKeyID
|
||||
rules:
|
||||
- name: do-not-use-public-schema-in-queries
|
||||
message: "do not use public schema in queries"
|
||||
|
||||
@@ -42,6 +42,10 @@ const (
|
||||
UniqueJfrogXrayScansPkey UniqueConstraint = "jfrog_xray_scans_pkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id);
|
||||
UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt);
|
||||
UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id);
|
||||
UniqueMcpServerConfigsPkey UniqueConstraint = "mcp_server_configs_pkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id);
|
||||
UniqueMcpServerConfigsSlugKey UniqueConstraint = "mcp_server_configs_slug_key" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug);
|
||||
UniqueMcpServerUserTokensMcpServerConfigIDUserIDKey UniqueConstraint = "mcp_server_user_tokens_mcp_server_config_id_user_id_key" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id);
|
||||
UniqueMcpServerUserTokensPkey UniqueConstraint = "mcp_server_user_tokens_pkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id);
|
||||
UniqueNotificationMessagesPkey UniqueConstraint = "notification_messages_pkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id);
|
||||
UniqueNotificationPreferencesPkey UniqueConstraint = "notification_preferences_pkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_pkey PRIMARY KEY (user_id, notification_template_id);
|
||||
UniqueNotificationReportGeneratorLogsPkey UniqueConstraint = "notification_report_generator_logs_pkey" // ALTER TABLE ONLY notification_report_generator_logs ADD CONSTRAINT notification_report_generator_logs_pkey PRIMARY KEY (notification_template_id);
|
||||
|
||||
+921
@@ -0,0 +1,921 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary List MCP server configs
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) listMCPServerConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Admin users can see all MCP server configs (including disabled
|
||||
// ones) for management purposes. Non-admin users see only enabled
|
||||
// configs, which is sufficient for using the chat feature.
|
||||
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
||||
|
||||
var configs []database.MCPServerConfig
|
||||
var err error
|
||||
if isAdmin {
|
||||
configs, err = api.Database.GetMCPServerConfigs(ctx)
|
||||
} else {
|
||||
//nolint:gocritic // All authenticated users need to read enabled MCP server configs to use the chat feature.
|
||||
configs, err = api.Database.GetEnabledMCPServerConfigs(dbauthz.AsSystemRestricted(ctx))
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list MCP server configs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the calling user's OAuth2 tokens so we can populate
|
||||
// auth_connected per server.
|
||||
//nolint:gocritic // Need to check user tokens across all servers.
|
||||
userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get user tokens.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
tokenMap := make(map[uuid.UUID]bool, len(userTokens))
|
||||
for _, t := range userTokens {
|
||||
tokenMap[t.MCPServerConfigID] = true
|
||||
}
|
||||
|
||||
resp := make([]codersdk.MCPServerConfig, 0, len(configs))
|
||||
for _, config := range configs {
|
||||
var sdkConfig codersdk.MCPServerConfig
|
||||
if isAdmin {
|
||||
sdkConfig = convertMCPServerConfig(config)
|
||||
} else {
|
||||
sdkConfig = convertMCPServerConfigRedacted(config)
|
||||
}
|
||||
if config.AuthType == "oauth2" {
|
||||
sdkConfig.AuthConnected = tokenMap[config.ID]
|
||||
} else {
|
||||
sdkConfig.AuthConnected = true
|
||||
}
|
||||
resp = append(resp, sdkConfig)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// @Summary Create MCP server config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.CreateMCPServerConfigRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate auth-type-dependent fields.
|
||||
switch req.AuthType {
|
||||
case "oauth2":
|
||||
if req.OAuth2ClientID == "" || req.OAuth2AuthURL == "" || req.OAuth2TokenURL == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "OAuth2 auth type requires oauth2_client_id, oauth2_auth_url, and oauth2_token_url.",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "api_key":
|
||||
if req.APIKeyHeader == "" || req.APIKeyValue == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key auth type requires api_key_header and api_key_value.",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "custom_headers":
|
||||
if len(req.CustomHeaders) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Custom headers auth type requires at least one custom header.",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
customHeadersJSON, err := marshalCustomHeaders(req.CustomHeaders)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid custom headers.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
Slug: strings.TrimSpace(req.Slug),
|
||||
Description: strings.TrimSpace(req.Description),
|
||||
IconURL: strings.TrimSpace(req.IconURL),
|
||||
Transport: strings.TrimSpace(req.Transport),
|
||||
Url: strings.TrimSpace(req.URL),
|
||||
AuthType: strings.TrimSpace(req.AuthType),
|
||||
OAuth2ClientID: strings.TrimSpace(req.OAuth2ClientID),
|
||||
OAuth2ClientSecret: strings.TrimSpace(req.OAuth2ClientSecret),
|
||||
OAuth2ClientSecretKeyID: sql.NullString{},
|
||||
OAuth2AuthURL: strings.TrimSpace(req.OAuth2AuthURL),
|
||||
OAuth2TokenURL: strings.TrimSpace(req.OAuth2TokenURL),
|
||||
OAuth2Scopes: strings.TrimSpace(req.OAuth2Scopes),
|
||||
APIKeyHeader: strings.TrimSpace(req.APIKeyHeader),
|
||||
APIKeyValue: strings.TrimSpace(req.APIKeyValue),
|
||||
APIKeyValueKeyID: sql.NullString{},
|
||||
CustomHeaders: customHeadersJSON,
|
||||
CustomHeadersKeyID: sql.NullString{},
|
||||
ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)),
|
||||
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
|
||||
Availability: strings.TrimSpace(req.Availability),
|
||||
Enabled: req.Enabled,
|
||||
CreatedBy: apiKey.UserID,
|
||||
UpdatedBy: apiKey.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
switch {
|
||||
case database.IsUniqueViolation(err):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "MCP server config already exists.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
case database.IsCheckViolation(err):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, convertMCPServerConfig(inserted))
|
||||
}
|
||||
|
||||
// @Summary Get MCP server config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) getMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
mcpServerID, ok := parseMCPServerConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
isAdmin := api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
||||
|
||||
var config database.MCPServerConfig
|
||||
var err error
|
||||
if isAdmin {
|
||||
config, err = api.Database.GetMCPServerConfigByID(ctx, mcpServerID)
|
||||
} else {
|
||||
//nolint:gocritic // All authenticated users can view enabled MCP server configs.
|
||||
config, err = api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID)
|
||||
if err == nil && !config.Enabled {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var sdkConfig codersdk.MCPServerConfig
|
||||
if isAdmin {
|
||||
sdkConfig = convertMCPServerConfig(config)
|
||||
} else {
|
||||
sdkConfig = convertMCPServerConfigRedacted(config)
|
||||
}
|
||||
|
||||
// Populate AuthConnected for the calling user.
|
||||
if config.AuthType == "oauth2" {
|
||||
//nolint:gocritic // Need to check user token for this server.
|
||||
userTokens, err := api.Database.GetMCPServerUserTokensByUserID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get user tokens.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
for _, t := range userTokens {
|
||||
if t.MCPServerConfigID == config.ID {
|
||||
sdkConfig.AuthConnected = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sdkConfig.AuthConnected = true
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, sdkConfig)
|
||||
}
|
||||
|
||||
// @Summary Update MCP server config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
mcpServerID, ok := parseMCPServerConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateMCPServerConfigRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-validate custom headers before entering the transaction.
|
||||
var customHeadersJSON string
|
||||
if req.CustomHeaders != nil {
|
||||
var chErr error
|
||||
customHeadersJSON, chErr = marshalCustomHeaders(*req.CustomHeaders)
|
||||
if chErr != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid custom headers.",
|
||||
Detail: chErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var updated database.MCPServerConfig
|
||||
err := api.Database.InTx(func(tx database.Store) error {
|
||||
existing, err := tx.GetMCPServerConfigByID(ctx, mcpServerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
displayName := existing.DisplayName
|
||||
if req.DisplayName != nil {
|
||||
displayName = strings.TrimSpace(*req.DisplayName)
|
||||
}
|
||||
|
||||
slug := existing.Slug
|
||||
if req.Slug != nil {
|
||||
slug = strings.TrimSpace(*req.Slug)
|
||||
}
|
||||
|
||||
description := existing.Description
|
||||
if req.Description != nil {
|
||||
description = strings.TrimSpace(*req.Description)
|
||||
}
|
||||
|
||||
iconURL := existing.IconURL
|
||||
if req.IconURL != nil {
|
||||
iconURL = strings.TrimSpace(*req.IconURL)
|
||||
}
|
||||
|
||||
transport := existing.Transport
|
||||
if req.Transport != nil {
|
||||
transport = strings.TrimSpace(*req.Transport)
|
||||
}
|
||||
|
||||
serverURL := existing.Url
|
||||
if req.URL != nil {
|
||||
serverURL = strings.TrimSpace(*req.URL)
|
||||
}
|
||||
|
||||
authType := existing.AuthType
|
||||
if req.AuthType != nil {
|
||||
authType = strings.TrimSpace(*req.AuthType)
|
||||
}
|
||||
|
||||
oauth2ClientID := existing.OAuth2ClientID
|
||||
if req.OAuth2ClientID != nil {
|
||||
oauth2ClientID = strings.TrimSpace(*req.OAuth2ClientID)
|
||||
}
|
||||
|
||||
oauth2ClientSecret := existing.OAuth2ClientSecret
|
||||
oauth2ClientSecretKeyID := existing.OAuth2ClientSecretKeyID
|
||||
if req.OAuth2ClientSecret != nil {
|
||||
oauth2ClientSecret = strings.TrimSpace(*req.OAuth2ClientSecret)
|
||||
// Clear the key ID when the secret is explicitly updated.
|
||||
oauth2ClientSecretKeyID = sql.NullString{}
|
||||
}
|
||||
|
||||
oauth2AuthURL := existing.OAuth2AuthURL
|
||||
if req.OAuth2AuthURL != nil {
|
||||
oauth2AuthURL = strings.TrimSpace(*req.OAuth2AuthURL)
|
||||
}
|
||||
|
||||
oauth2TokenURL := existing.OAuth2TokenURL
|
||||
if req.OAuth2TokenURL != nil {
|
||||
oauth2TokenURL = strings.TrimSpace(*req.OAuth2TokenURL)
|
||||
}
|
||||
|
||||
oauth2Scopes := existing.OAuth2Scopes
|
||||
if req.OAuth2Scopes != nil {
|
||||
oauth2Scopes = strings.TrimSpace(*req.OAuth2Scopes)
|
||||
}
|
||||
|
||||
apiKeyHeader := existing.APIKeyHeader
|
||||
if req.APIKeyHeader != nil {
|
||||
apiKeyHeader = strings.TrimSpace(*req.APIKeyHeader)
|
||||
}
|
||||
|
||||
apiKeyValue := existing.APIKeyValue
|
||||
apiKeyValueKeyID := existing.APIKeyValueKeyID
|
||||
if req.APIKeyValue != nil {
|
||||
apiKeyValue = strings.TrimSpace(*req.APIKeyValue)
|
||||
// Clear the key ID when the value is explicitly updated.
|
||||
apiKeyValueKeyID = sql.NullString{}
|
||||
}
|
||||
|
||||
customHeaders := existing.CustomHeaders
|
||||
customHeadersKeyID := existing.CustomHeadersKeyID
|
||||
if req.CustomHeaders != nil {
|
||||
customHeaders = customHeadersJSON
|
||||
// Clear the key ID when headers are explicitly updated.
|
||||
customHeadersKeyID = sql.NullString{}
|
||||
}
|
||||
|
||||
toolAllowList := existing.ToolAllowList
|
||||
if req.ToolAllowList != nil {
|
||||
toolAllowList = coalesceStringSlice(trimStringSlice(*req.ToolAllowList))
|
||||
}
|
||||
|
||||
toolDenyList := existing.ToolDenyList
|
||||
if req.ToolDenyList != nil {
|
||||
toolDenyList = coalesceStringSlice(trimStringSlice(*req.ToolDenyList))
|
||||
}
|
||||
|
||||
availability := existing.Availability
|
||||
if req.Availability != nil {
|
||||
availability = strings.TrimSpace(*req.Availability)
|
||||
}
|
||||
|
||||
enabled := existing.Enabled
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
|
||||
// When auth_type changes, clear fields belonging to the
|
||||
// previous auth type so stale secrets don't persist.
|
||||
if authType != existing.AuthType {
|
||||
switch authType {
|
||||
case "none":
|
||||
oauth2ClientID = ""
|
||||
oauth2ClientSecret = ""
|
||||
oauth2ClientSecretKeyID = sql.NullString{}
|
||||
oauth2AuthURL = ""
|
||||
oauth2TokenURL = ""
|
||||
oauth2Scopes = ""
|
||||
apiKeyHeader = ""
|
||||
apiKeyValue = ""
|
||||
apiKeyValueKeyID = sql.NullString{}
|
||||
customHeaders = "{}"
|
||||
customHeadersKeyID = sql.NullString{}
|
||||
case "oauth2":
|
||||
apiKeyHeader = ""
|
||||
apiKeyValue = ""
|
||||
apiKeyValueKeyID = sql.NullString{}
|
||||
customHeaders = "{}"
|
||||
customHeadersKeyID = sql.NullString{}
|
||||
case "api_key":
|
||||
oauth2ClientID = ""
|
||||
oauth2ClientSecret = ""
|
||||
oauth2ClientSecretKeyID = sql.NullString{}
|
||||
oauth2AuthURL = ""
|
||||
oauth2TokenURL = ""
|
||||
oauth2Scopes = ""
|
||||
customHeaders = "{}"
|
||||
customHeadersKeyID = sql.NullString{}
|
||||
case "custom_headers":
|
||||
oauth2ClientID = ""
|
||||
oauth2ClientSecret = ""
|
||||
oauth2ClientSecretKeyID = sql.NullString{}
|
||||
oauth2AuthURL = ""
|
||||
oauth2TokenURL = ""
|
||||
oauth2Scopes = ""
|
||||
apiKeyHeader = ""
|
||||
apiKeyValue = ""
|
||||
apiKeyValueKeyID = sql.NullString{}
|
||||
}
|
||||
}
|
||||
|
||||
updated, err = tx.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{
|
||||
DisplayName: displayName,
|
||||
Slug: slug,
|
||||
Description: description,
|
||||
IconURL: iconURL,
|
||||
Transport: transport,
|
||||
Url: serverURL,
|
||||
AuthType: authType,
|
||||
OAuth2ClientID: oauth2ClientID,
|
||||
OAuth2ClientSecret: oauth2ClientSecret,
|
||||
OAuth2ClientSecretKeyID: oauth2ClientSecretKeyID,
|
||||
OAuth2AuthURL: oauth2AuthURL,
|
||||
OAuth2TokenURL: oauth2TokenURL,
|
||||
OAuth2Scopes: oauth2Scopes,
|
||||
APIKeyHeader: apiKeyHeader,
|
||||
APIKeyValue: apiKeyValue,
|
||||
APIKeyValueKeyID: apiKeyValueKeyID,
|
||||
CustomHeaders: customHeaders,
|
||||
CustomHeadersKeyID: customHeadersKeyID,
|
||||
ToolAllowList: toolAllowList,
|
||||
ToolDenyList: toolDenyList,
|
||||
Availability: availability,
|
||||
Enabled: enabled,
|
||||
UpdatedBy: apiKey.UserID,
|
||||
ID: existing.ID,
|
||||
})
|
||||
return err
|
||||
}, nil)
|
||||
if err != nil {
|
||||
switch {
|
||||
case httpapi.Is404Error(err):
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
case database.IsUniqueViolation(err):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "MCP server config slug already exists.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
case database.IsCheckViolation(err):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertMCPServerConfig(updated))
|
||||
}
|
||||
|
||||
// @Summary Delete MCP server config
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) deleteMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
mcpServerID, ok := parseMCPServerConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Database.GetMCPServerConfigByID(ctx, mcpServerID); err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Database.DeleteMCPServerConfigByID(ctx, mcpServerID); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Initiate MCP server OAuth2 connect
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
// Redirects the user to the MCP server's OAuth2 authorization URL.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) mcpServerOAuth2Connect(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
mcpServerID, ok := parseMCPServerConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // Any authenticated user can initiate OAuth2 for an enabled MCP server.
|
||||
config, err := api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !config.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "MCP server is not enabled.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if config.AuthType != "oauth2" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "MCP server does not use OAuth2 authentication.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if config.OAuth2AuthURL == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "MCP server OAuth2 authorization URL is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Build the authorization URL. The frontend opens this in a popup.
|
||||
// The callback URL is on our server; after the exchange we store
|
||||
// the token and close the popup.
|
||||
state := uuid.New().String()
|
||||
http.SetCookie(rw, api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
|
||||
Name: "mcp_oauth2_state_" + config.ID.String(),
|
||||
Value: state,
|
||||
Path: fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/callback", config.ID),
|
||||
MaxAge: 600, // 10 minutes
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}))
|
||||
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: config.OAuth2ClientID,
|
||||
ClientSecret: config.OAuth2ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.OAuth2AuthURL,
|
||||
TokenURL: config.OAuth2TokenURL,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), config.ID),
|
||||
}
|
||||
var scopes []string
|
||||
if config.OAuth2Scopes != "" {
|
||||
scopes = strings.Split(config.OAuth2Scopes, " ")
|
||||
}
|
||||
oauth2Config.Scopes = scopes
|
||||
authURL := oauth2Config.AuthCodeURL(state)
|
||||
http.Redirect(rw, r, authURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// @Summary Handle MCP server OAuth2 callback
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
// Exchanges the authorization code for tokens and stores them.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) mcpServerOAuth2Callback(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
mcpServerID, ok := parseMCPServerConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // Any authenticated user can complete OAuth2 for an enabled MCP server.
|
||||
config, err := api.Database.GetMCPServerConfigByID(dbauthz.AsSystemRestricted(ctx), mcpServerID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get MCP server config.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !config.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "MCP server is not enabled.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if config.AuthType != "oauth2" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "MCP server does not use OAuth2 authentication.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the OAuth2 provider returned an error (e.g., user
|
||||
// denied consent).
|
||||
if oauthError := r.URL.Query().Get("error"); oauthError != "" {
|
||||
desc := r.URL.Query().Get("error_description")
|
||||
if desc == "" {
|
||||
desc = oauthError
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "OAuth2 provider returned an error.",
|
||||
Detail: desc,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing authorization code.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the state parameter for CSRF protection.
|
||||
expectedState := ""
|
||||
if cookie, err := r.Cookie("mcp_oauth2_state_" + config.ID.String()); err == nil {
|
||||
expectedState = cookie.Value
|
||||
}
|
||||
actualState := r.URL.Query().Get("state")
|
||||
if expectedState == "" || actualState != expectedState {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid or missing OAuth2 state parameter.",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Clear the state cookie.
|
||||
http.SetCookie(rw, api.DeploymentValues.HTTPCookies.Apply(&http.Cookie{
|
||||
Name: "mcp_oauth2_state_" + config.ID.String(),
|
||||
Value: "",
|
||||
Path: fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/callback", config.ID),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}))
|
||||
|
||||
// Exchange the authorization code for tokens.
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: config.OAuth2ClientID,
|
||||
ClientSecret: config.OAuth2ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.OAuth2AuthURL,
|
||||
TokenURL: config.OAuth2TokenURL,
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/callback", api.AccessURL.String(), config.ID),
|
||||
}
|
||||
var scopes []string
|
||||
if config.OAuth2Scopes != "" {
|
||||
scopes = strings.Split(config.OAuth2Scopes, " ")
|
||||
}
|
||||
oauth2Config.Scopes = scopes
|
||||
|
||||
// Use the deployment's HTTP client for the token exchange to
|
||||
// respect proxy settings and avoid using http.DefaultClient.
|
||||
exchangeCtx := context.WithValue(ctx, oauth2.HTTPClient, api.HTTPClient)
|
||||
token, err := oauth2Config.Exchange(exchangeCtx, code)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadGateway, codersdk.Response{
|
||||
Message: "Failed to exchange authorization code for token.",
|
||||
Detail: "The OAuth2 token exchange with the upstream provider failed.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Store the token for the user.
|
||||
refreshToken := ""
|
||||
if token.RefreshToken != "" {
|
||||
refreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
var expiry sql.NullTime
|
||||
if !token.Expiry.IsZero() {
|
||||
expiry = sql.NullTime{Time: token.Expiry, Valid: true}
|
||||
}
|
||||
|
||||
//nolint:gocritic // Users store their own tokens.
|
||||
_, err = api.Database.UpsertMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: mcpServerID,
|
||||
UserID: apiKey.UserID,
|
||||
AccessToken: token.AccessToken,
|
||||
AccessTokenKeyID: sql.NullString{},
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenKeyID: sql.NullString{},
|
||||
TokenType: token.TokenType,
|
||||
Expiry: expiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to store OAuth2 token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Respond with a simple HTML page that closes the popup window.
|
||||
rw.Header().Set("Content-Security-Policy", "default-src 'none'; script-src 'unsafe-inline'")
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write([]byte(`<!DOCTYPE html><html><body><script>
|
||||
if (window.opener) {
|
||||
window.opener.postMessage({type: "mcp-oauth2-complete", serverID: "` + config.ID.String() + `"}, "` + api.AccessURL.String() + `");
|
||||
window.close();
|
||||
} else {
|
||||
document.body.innerText = "Authentication successful. You may close this window.";
|
||||
}
|
||||
</script></body></html>`))
|
||||
}
|
||||
|
||||
// @Summary Disconnect MCP server OAuth2 token
|
||||
// @x-apidocgen {"skip": true}
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
// Removes the user's stored OAuth2 token for an MCP server.
|
||||
func (api *API) mcpServerOAuth2Disconnect(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
mcpServerID, ok := parseMCPServerConfigID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // Users manage their own tokens.
|
||||
err := api.Database.DeleteMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), database.DeleteMCPServerUserTokenParams{
|
||||
MCPServerConfigID: mcpServerID,
|
||||
UserID: apiKey.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to disconnect OAuth2 token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// parseMCPServerConfigID extracts the MCP server config UUID from the
|
||||
// "mcpServer" path parameter.
|
||||
func parseMCPServerConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
|
||||
mcpServerID, err := uuid.Parse(chi.URLParam(r, "mcpServer"))
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid MCP server config ID.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return uuid.Nil, false
|
||||
}
|
||||
return mcpServerID, true
|
||||
}
|
||||
|
||||
// convertMCPServerConfig converts a database MCP server config to the
|
||||
// SDK type. Secrets are never returned; only has_* booleans are set.
|
||||
// Admin-only fields (OAuth2 client ID, auth URLs, etc.) are included.
|
||||
func convertMCPServerConfig(config database.MCPServerConfig) codersdk.MCPServerConfig {
|
||||
return codersdk.MCPServerConfig{
|
||||
ID: config.ID,
|
||||
DisplayName: config.DisplayName,
|
||||
Slug: config.Slug,
|
||||
Description: config.Description,
|
||||
IconURL: config.IconURL,
|
||||
|
||||
Transport: config.Transport,
|
||||
URL: config.Url,
|
||||
|
||||
AuthType: config.AuthType,
|
||||
OAuth2ClientID: config.OAuth2ClientID,
|
||||
HasOAuth2Secret: config.OAuth2ClientSecret != "",
|
||||
OAuth2AuthURL: config.OAuth2AuthURL,
|
||||
OAuth2TokenURL: config.OAuth2TokenURL,
|
||||
OAuth2Scopes: config.OAuth2Scopes,
|
||||
|
||||
APIKeyHeader: config.APIKeyHeader,
|
||||
HasAPIKey: config.APIKeyValue != "",
|
||||
|
||||
HasCustomHeaders: len(config.CustomHeaders) > 0 && config.CustomHeaders != "{}",
|
||||
|
||||
ToolAllowList: coalesceStringSlice(config.ToolAllowList),
|
||||
ToolDenyList: coalesceStringSlice(config.ToolDenyList),
|
||||
|
||||
Availability: config.Availability,
|
||||
|
||||
Enabled: config.Enabled,
|
||||
CreatedAt: config.CreatedAt,
|
||||
UpdatedAt: config.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// convertMCPServerConfigRedacted is the same as convertMCPServerConfig
|
||||
// but strips admin-only fields (OAuth2 details, API key header) for
|
||||
// non-admin callers.
|
||||
func convertMCPServerConfigRedacted(config database.MCPServerConfig) codersdk.MCPServerConfig {
|
||||
c := convertMCPServerConfig(config)
|
||||
c.URL = ""
|
||||
c.Transport = ""
|
||||
c.OAuth2ClientID = ""
|
||||
c.OAuth2AuthURL = ""
|
||||
c.OAuth2TokenURL = ""
|
||||
c.OAuth2Scopes = ""
|
||||
c.APIKeyHeader = ""
|
||||
return c
|
||||
}
|
||||
|
||||
// marshalCustomHeaders encodes a map of custom headers to JSON for
|
||||
// database storage. A nil map produces an empty JSON object.
|
||||
func marshalCustomHeaders(headers map[string]string) (string, error) {
|
||||
if headers == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
encoded, err := json.Marshal(headers)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
// trimStringSlice trims whitespace from each element and drops empty
|
||||
// strings.
|
||||
func trimStringSlice(ss []string) []string {
|
||||
if ss == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(ss))
|
||||
for _, s := range ss {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// coalesceStringSlice returns ss if non-nil, otherwise an empty
|
||||
// non-nil slice. This prevents pq.Array from sending NULL for
|
||||
// NOT NULL text[] columns.
|
||||
func coalesceStringSlice(ss []string) []string {
|
||||
if ss == nil {
|
||||
return []string{}
|
||||
}
|
||||
return ss
|
||||
}
|
||||
@@ -0,0 +1,489 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// mcpDeploymentValues returns deployment values with the agents
|
||||
// experiment enabled, which is required by the MCP server config
|
||||
// endpoints.
|
||||
func mcpDeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
||||
t.Helper()
|
||||
|
||||
values := coderdtest.DeploymentValues(t)
|
||||
values.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
return values
|
||||
}
|
||||
|
||||
// newMCPClient creates a test server with the agents experiment
|
||||
// enabled and returns the admin client.
|
||||
func newMCPClient(t testing.TB) *codersdk.Client {
|
||||
t.Helper()
|
||||
|
||||
return coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: mcpDeploymentValues(t),
|
||||
})
|
||||
}
|
||||
|
||||
// createMCPServerConfig is a helper that creates a minimal enabled
|
||||
// MCP server config with auth_type=none.
|
||||
func createMCPServerConfig(t testing.TB, client *codersdk.Client, slug string, enabled bool) codersdk.MCPServerConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
config, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Test Server " + slug,
|
||||
Slug: slug,
|
||||
Description: "A test MCP server.",
|
||||
IconURL: "https://example.com/icon.png",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/" + slug,
|
||||
AuthType: "none",
|
||||
Availability: "default_on",
|
||||
Enabled: enabled,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return config
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsCRUD(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a config with all fields populated including OAuth2
|
||||
// secrets so we can verify they are not leaked.
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "My MCP Server",
|
||||
Slug: "my-mcp-server",
|
||||
Description: "Integration test server.",
|
||||
IconURL: "https://example.com/icon.png",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/v1",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "client-id-123",
|
||||
OAuth2ClientSecret: "super-secret-value",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
OAuth2Scopes: "read write",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, created.ID)
|
||||
require.Equal(t, "My MCP Server", created.DisplayName)
|
||||
require.Equal(t, "my-mcp-server", created.Slug)
|
||||
require.Equal(t, "Integration test server.", created.Description)
|
||||
require.Equal(t, "streamable_http", created.Transport)
|
||||
require.Equal(t, "https://mcp.example.com/v1", created.URL)
|
||||
require.Equal(t, "oauth2", created.AuthType)
|
||||
require.Equal(t, "client-id-123", created.OAuth2ClientID)
|
||||
require.Equal(t, "default_on", created.Availability)
|
||||
require.True(t, created.Enabled)
|
||||
|
||||
// Verify the secret is indicated but never returned.
|
||||
require.True(t, created.HasOAuth2Secret)
|
||||
|
||||
// Verify the config appears in the list.
|
||||
configs, err := client.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, configs, 1)
|
||||
require.Equal(t, created.ID, configs[0].ID)
|
||||
require.True(t, configs[0].HasOAuth2Secret)
|
||||
|
||||
// Update display name and availability.
|
||||
newName := "Renamed Server"
|
||||
newAvail := "force_on"
|
||||
updated, err := client.UpdateMCPServerConfig(ctx, created.ID, codersdk.UpdateMCPServerConfigRequest{
|
||||
DisplayName: &newName,
|
||||
Availability: &newAvail,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Renamed Server", updated.DisplayName)
|
||||
require.Equal(t, "force_on", updated.Availability)
|
||||
// Unchanged fields should remain the same.
|
||||
require.Equal(t, "my-mcp-server", updated.Slug)
|
||||
require.Equal(t, "oauth2", updated.AuthType)
|
||||
|
||||
// Verify the update took effect through the list.
|
||||
configs, err = client.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, configs, 1)
|
||||
require.Equal(t, "Renamed Server", configs[0].DisplayName)
|
||||
require.Equal(t, "force_on", configs[0].Availability)
|
||||
|
||||
// Delete it.
|
||||
err = client.DeleteMCPServerConfig(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
configs, err = client.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, configs)
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsNonAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newMCPClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
// Admin creates two configs: one enabled, one disabled.
|
||||
_ = createMCPServerConfig(t, adminClient, "enabled-server", true)
|
||||
_ = createMCPServerConfig(t, adminClient, "disabled-server", false)
|
||||
|
||||
// Admin sees both.
|
||||
adminConfigs, err := adminClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, adminConfigs, 2)
|
||||
|
||||
// Regular user sees only the enabled one.
|
||||
memberConfigs, err := memberClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberConfigs, 1)
|
||||
require.Equal(t, "enabled-server", memberConfigs[0].Slug)
|
||||
}
|
||||
|
||||
// TestMCPServerConfigsSecretsNeverLeaked is a load-bearing test that
|
||||
// ensures secret fields (OAuth2 client secret, API key value, custom
|
||||
// headers) are never present in API responses for any caller. If this
|
||||
// test fails, it means a code change accidentally started exposing
|
||||
// secrets. See: https://github.com/coder/coder/pull/23227#discussion_r2959461109
|
||||
func TestMCPServerConfigsSecretsNeverLeaked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newMCPClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
// Create a config with ALL secret fields populated.
|
||||
created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Secrets Test",
|
||||
Slug: "secrets-test",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/secrets",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "client-id-secret-test",
|
||||
OAuth2ClientSecret: "THIS-IS-A-SECRET-VALUE",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
OAuth2Scopes: "read write",
|
||||
APIKeyHeader: "X-Api-Key",
|
||||
APIKeyValue: "THIS-IS-A-SECRET-API-KEY",
|
||||
CustomHeaders: map[string]string{"X-Custom": "THIS-IS-A-SECRET-HEADER"},
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The sentinel values we must never see in any JSON response.
|
||||
secrets := []string{
|
||||
"THIS-IS-A-SECRET-VALUE",
|
||||
"THIS-IS-A-SECRET-API-KEY",
|
||||
"THIS-IS-A-SECRET-HEADER",
|
||||
}
|
||||
|
||||
assertNoSecrets := func(t *testing.T, label string, v interface{}) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
jsonStr := string(data)
|
||||
for _, secret := range secrets {
|
||||
assert.False(t, strings.Contains(jsonStr, secret),
|
||||
"%s: JSON response contains secret %q", label, secret)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the create response doesn't leak secrets.
|
||||
assertNoSecrets(t, "admin create response", created)
|
||||
|
||||
// Verify boolean indicators are set correctly.
|
||||
require.True(t, created.HasOAuth2Secret, "HasOAuth2Secret should be true")
|
||||
require.True(t, created.HasAPIKey, "HasAPIKey should be true")
|
||||
require.True(t, created.HasCustomHeaders, "HasCustomHeaders should be true")
|
||||
|
||||
// Admin list endpoint.
|
||||
adminConfigs, err := adminClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, adminConfigs)
|
||||
for _, cfg := range adminConfigs {
|
||||
assertNoSecrets(t, "admin list", cfg)
|
||||
}
|
||||
|
||||
// Admin get-by-ID endpoint.
|
||||
adminSingle, err := adminClient.MCPServerConfigByID(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
assertNoSecrets(t, "admin get-by-id", adminSingle)
|
||||
|
||||
// Non-admin list endpoint.
|
||||
memberConfigs, err := memberClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, memberConfigs)
|
||||
for _, cfg := range memberConfigs {
|
||||
assertNoSecrets(t, "member list", cfg)
|
||||
// Non-admin should also not see admin-only fields.
|
||||
assert.Empty(t, cfg.OAuth2ClientID, "member should not see OAuth2ClientID")
|
||||
assert.Empty(t, cfg.OAuth2AuthURL, "member should not see OAuth2AuthURL")
|
||||
assert.Empty(t, cfg.OAuth2TokenURL, "member should not see OAuth2TokenURL")
|
||||
assert.Empty(t, cfg.APIKeyHeader, "member should not see APIKeyHeader")
|
||||
assert.Empty(t, cfg.OAuth2Scopes, "member should not see OAuth2Scopes")
|
||||
assert.Empty(t, cfg.URL, "member should not see URL")
|
||||
assert.Empty(t, cfg.Transport, "member should not see Transport")
|
||||
}
|
||||
|
||||
// Non-admin get-by-ID endpoint.
|
||||
memberSingle, err := memberClient.MCPServerConfigByID(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
assertNoSecrets(t, "member get-by-id", memberSingle)
|
||||
assert.Empty(t, memberSingle.OAuth2ClientID, "member should not see OAuth2ClientID")
|
||||
assert.Empty(t, memberSingle.OAuth2AuthURL, "member should not see OAuth2AuthURL")
|
||||
assert.Empty(t, memberSingle.OAuth2TokenURL, "member should not see OAuth2TokenURL")
|
||||
assert.Empty(t, memberSingle.OAuth2Scopes, "member should not see OAuth2Scopes")
|
||||
assert.Empty(t, memberSingle.APIKeyHeader, "member should not see APIKeyHeader")
|
||||
assert.Empty(t, memberSingle.URL, "member should not see URL")
|
||||
assert.Empty(t, memberSingle.Transport, "member should not see Transport")
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsAuthConnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newMCPClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
// Create an oauth2 server config (enabled).
|
||||
created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "OAuth Server",
|
||||
Slug: "oauth-server",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/oauth",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "cid",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Regular user lists configs — auth_connected should be false
|
||||
// because no token has been stored.
|
||||
memberConfigs, err := memberClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberConfigs, 1)
|
||||
require.Equal(t, created.ID, memberConfigs[0].ID)
|
||||
require.False(t, memberConfigs[0].AuthConnected)
|
||||
|
||||
// Also create a non-oauth server. It should report
|
||||
// auth_connected=true because no auth is needed.
|
||||
_ = createMCPServerConfig(t, adminClient, "no-auth-server", true)
|
||||
memberConfigs, err = memberClient.MCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, memberConfigs, 2)
|
||||
for _, cfg := range memberConfigs {
|
||||
if cfg.AuthType == "none" {
|
||||
require.True(t, cfg.AuthConnected)
|
||||
} else {
|
||||
require.False(t, cfg.AuthConnected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsAvailability(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
validValues := []string{"force_on", "default_on", "default_off"}
|
||||
for _, av := range validValues {
|
||||
av := av
|
||||
t.Run(av, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
created, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Server " + av,
|
||||
Slug: "server-" + av,
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/" + av,
|
||||
AuthType: "none",
|
||||
Availability: av,
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, av, created.Availability)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("InvalidAvailability", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Bad Availability",
|
||||
Slug: "bad-avail",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/bad",
|
||||
AuthType: "none",
|
||||
Availability: "always_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsUniqueSlug(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
_, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "First",
|
||||
Slug: "test-server",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/first",
|
||||
AuthType: "none",
|
||||
Availability: "default_off",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to create another config with the same slug.
|
||||
_, err = client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "Second",
|
||||
Slug: "test-server",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/second",
|
||||
AuthType: "none",
|
||||
Availability: "default_off",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
}
|
||||
|
||||
func TestMCPServerConfigsOAuth2Disconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newMCPClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
created, err := adminClient.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
||||
DisplayName: "OAuth Disconnect Test",
|
||||
Slug: "oauth-disconnect",
|
||||
Transport: "streamable_http",
|
||||
URL: "https://mcp.example.com/oauth-disc",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "cid",
|
||||
OAuth2AuthURL: "https://auth.example.com/authorize",
|
||||
OAuth2TokenURL: "https://auth.example.com/token",
|
||||
Availability: "default_on",
|
||||
Enabled: true,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Disconnect should succeed even when no token exists (idempotent).
|
||||
err = memberClient.MCPServerOAuth2Disconnect(ctx, created.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestChatWithMCPServerIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newMCPClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create the chat model config required for creating a chat.
|
||||
_ = createChatModelConfigForMCP(t, client)
|
||||
|
||||
// Create an enabled MCP server config.
|
||||
mcpConfig := createMCPServerConfig(t, client, "chat-mcp-server", true)
|
||||
|
||||
// Create a chat referencing the MCP server.
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello with mcp server",
|
||||
},
|
||||
},
|
||||
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, uuid.Nil, chat.ID)
|
||||
require.Contains(t, chat.MCPServerIDs, mcpConfig.ID)
|
||||
|
||||
// Fetch the chat and verify the MCP server IDs persist.
|
||||
fetched, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, fetched.MCPServerIDs, mcpConfig.ID)
|
||||
}
|
||||
|
||||
// createChatModelConfigForMCP sets up a chat provider and model
|
||||
// config so that CreateChat succeeds. This mirrors the helper in
|
||||
// chats_test.go but is defined here to avoid coupling.
|
||||
func createChatModelConfigForMCP(t testing.TB, client *codersdk.Client) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return modelConfig
|
||||
}
|
||||
@@ -41,7 +41,7 @@ func (api *API) PrimaryRegion(ctx context.Context) (codersdk.Region, error) {
|
||||
ID: deploymentID,
|
||||
Name: "primary",
|
||||
DisplayName: proxy.DisplayName,
|
||||
IconURL: proxy.IconUrl,
|
||||
IconURL: proxy.IconURL,
|
||||
Healthy: true,
|
||||
PathAppURL: api.AccessURL.String(),
|
||||
WildcardHostname: appurl.SubdomainAppHost(api.AppHostname, api.AccessURL),
|
||||
|
||||
@@ -49,6 +49,7 @@ type Chat struct {
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
Archived bool `json:"archived"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in a chat.
|
||||
@@ -267,6 +268,7 @@ type CreateChatRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
}
|
||||
|
||||
// UpdateChatRequest is the request to update a chat.
|
||||
@@ -279,6 +281,7 @@ type UpdateChatRequest struct {
|
||||
type CreateChatMessageRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs *[]uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
}
|
||||
|
||||
// EditChatMessageRequest is the request to edit a user message in a chat.
|
||||
|
||||
+191
@@ -0,0 +1,191 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MCPServerOAuth2ConnectURL returns the URL the user should visit to
|
||||
// start the OAuth2 flow for an MCP server. The frontend opens this
|
||||
// in a new window/popup.
|
||||
func (c *Client) MCPServerOAuth2ConnectURL(id uuid.UUID) string {
|
||||
return fmt.Sprintf("%s/api/experimental/mcp/servers/%s/oauth2/connect", c.URL.String(), id)
|
||||
}
|
||||
|
||||
// MCPServerOAuth2Disconnect removes the user's OAuth2 token for an
|
||||
// MCP server.
|
||||
func (c *Client) MCPServerOAuth2Disconnect(ctx context.Context, id uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/mcp/servers/%s/oauth2/disconnect", id), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPServerConfig represents an admin-configured MCP server.
|
||||
type MCPServerConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Slug string `json:"slug"`
|
||||
Description string `json:"description"`
|
||||
IconURL string `json:"icon_url"`
|
||||
|
||||
Transport string `json:"transport"` // "streamable_http" or "sse"
|
||||
URL string `json:"url"`
|
||||
|
||||
AuthType string `json:"auth_type"` // "none", "oauth2", "api_key", "custom_headers"
|
||||
|
||||
// OAuth2 fields (only populated for admins).
|
||||
OAuth2ClientID string `json:"oauth2_client_id,omitempty"`
|
||||
HasOAuth2Secret bool `json:"has_oauth2_secret"`
|
||||
OAuth2AuthURL string `json:"oauth2_auth_url,omitempty"`
|
||||
OAuth2TokenURL string `json:"oauth2_token_url,omitempty"`
|
||||
OAuth2Scopes string `json:"oauth2_scopes,omitempty"`
|
||||
|
||||
// API key fields (only populated for admins).
|
||||
APIKeyHeader string `json:"api_key_header,omitempty"`
|
||||
HasAPIKey bool `json:"has_api_key"`
|
||||
|
||||
HasCustomHeaders bool `json:"has_custom_headers"`
|
||||
|
||||
// Tool governance.
|
||||
ToolAllowList []string `json:"tool_allow_list"`
|
||||
ToolDenyList []string `json:"tool_deny_list"`
|
||||
|
||||
// Availability policy set by admin.
|
||||
Availability string `json:"availability"` // "force_on", "default_on", "default_off"
|
||||
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
|
||||
// Per-user state (populated for non-admin requests).
|
||||
AuthConnected bool `json:"auth_connected"`
|
||||
}
|
||||
|
||||
// CreateMCPServerConfigRequest is the request to create a new MCP server config.
|
||||
type CreateMCPServerConfigRequest struct {
|
||||
DisplayName string `json:"display_name" validate:"required"`
|
||||
Slug string `json:"slug" validate:"required"`
|
||||
Description string `json:"description"`
|
||||
IconURL string `json:"icon_url"`
|
||||
|
||||
Transport string `json:"transport" validate:"required,oneof=streamable_http sse"`
|
||||
URL string `json:"url" validate:"required,url"`
|
||||
|
||||
AuthType string `json:"auth_type" validate:"required,oneof=none oauth2 api_key custom_headers"`
|
||||
OAuth2ClientID string `json:"oauth2_client_id,omitempty"`
|
||||
OAuth2ClientSecret string `json:"oauth2_client_secret,omitempty"`
|
||||
OAuth2AuthURL string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"`
|
||||
OAuth2TokenURL string `json:"oauth2_token_url,omitempty" validate:"omitempty,url"`
|
||||
OAuth2Scopes string `json:"oauth2_scopes,omitempty"`
|
||||
APIKeyHeader string `json:"api_key_header,omitempty"`
|
||||
APIKeyValue string `json:"api_key_value,omitempty"`
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"`
|
||||
|
||||
ToolAllowList []string `json:"tool_allow_list,omitempty"`
|
||||
ToolDenyList []string `json:"tool_deny_list,omitempty"`
|
||||
|
||||
Availability string `json:"availability" validate:"required,oneof=force_on default_on default_off"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// UpdateMCPServerConfigRequest is the request to update an MCP server config.
|
||||
type UpdateMCPServerConfigRequest struct {
|
||||
DisplayName *string `json:"display_name,omitempty"`
|
||||
Slug *string `json:"slug,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
IconURL *string `json:"icon_url,omitempty"`
|
||||
|
||||
Transport *string `json:"transport,omitempty" validate:"omitempty,oneof=streamable_http sse"`
|
||||
URL *string `json:"url,omitempty" validate:"omitempty,url"`
|
||||
|
||||
AuthType *string `json:"auth_type,omitempty" validate:"omitempty,oneof=none oauth2 api_key custom_headers"`
|
||||
OAuth2ClientID *string `json:"oauth2_client_id,omitempty"`
|
||||
OAuth2ClientSecret *string `json:"oauth2_client_secret,omitempty"`
|
||||
OAuth2AuthURL *string `json:"oauth2_auth_url,omitempty" validate:"omitempty,url"`
|
||||
OAuth2TokenURL *string `json:"oauth2_token_url,omitempty" validate:"omitempty,url"`
|
||||
OAuth2Scopes *string `json:"oauth2_scopes,omitempty"`
|
||||
APIKeyHeader *string `json:"api_key_header,omitempty"`
|
||||
APIKeyValue *string `json:"api_key_value,omitempty"`
|
||||
CustomHeaders *map[string]string `json:"custom_headers,omitempty"`
|
||||
|
||||
ToolAllowList *[]string `json:"tool_allow_list,omitempty"`
|
||||
ToolDenyList *[]string `json:"tool_deny_list,omitempty"`
|
||||
|
||||
Availability *string `json:"availability,omitempty" validate:"omitempty,oneof=force_on default_on default_off"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Client) MCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/mcp/servers", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var configs []MCPServerConfig
|
||||
return configs, json.NewDecoder(res.Body).Decode(&configs)
|
||||
}
|
||||
|
||||
func (c *Client) MCPServerConfigByID(ctx context.Context, id uuid.UUID) (MCPServerConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), nil)
|
||||
if err != nil {
|
||||
return MCPServerConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return MCPServerConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
var config MCPServerConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
func (c *Client) CreateMCPServerConfig(ctx context.Context, req CreateMCPServerConfigRequest) (MCPServerConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, "/api/experimental/mcp/servers", req)
|
||||
if err != nil {
|
||||
return MCPServerConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return MCPServerConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
var config MCPServerConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
func (c *Client) UpdateMCPServerConfig(ctx context.Context, id uuid.UUID, req UpdateMCPServerConfigRequest) (MCPServerConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), req)
|
||||
if err != nil {
|
||||
return MCPServerConfig{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return MCPServerConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
var config MCPServerConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteMCPServerConfig(ctx context.Context, id uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/mcp/servers/%s", id), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -204,7 +204,7 @@ func (api *API) patchPrimaryWorkspaceProxy(req codersdk.PatchWorkspaceProxy, rw
|
||||
|
||||
args := database.UpsertDefaultProxyParams{
|
||||
DisplayName: req.DisplayName,
|
||||
IconUrl: req.Icon,
|
||||
IconURL: req.Icon,
|
||||
}
|
||||
if req.DisplayName == "" || req.Icon == "" {
|
||||
// If the user has not specified an update value, use the existing value.
|
||||
@@ -217,7 +217,7 @@ func (api *API) patchPrimaryWorkspaceProxy(req codersdk.PatchWorkspaceProxy, rw
|
||||
args.DisplayName = existing.DisplayName
|
||||
}
|
||||
if req.Icon == "" {
|
||||
args.IconUrl = existing.IconUrl
|
||||
args.IconURL = existing.IconURL
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -471,6 +471,201 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// decryptMCPServerConfig decrypts all encrypted fields on a
|
||||
// single MCPServerConfig in place.
|
||||
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
|
||||
if err := db.decryptField(&cfg.OAuth2ClientSecret, cfg.OAuth2ClientSecretKeyID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := db.decryptField(&cfg.APIKeyValue, cfg.APIKeyValueKeyID); err != nil {
|
||||
return err
|
||||
}
|
||||
return db.decryptField(&cfg.CustomHeaders, cfg.CustomHeadersKeyID)
|
||||
}
|
||||
|
||||
// decryptMCPServerUserToken decrypts all encrypted fields on a
|
||||
// single MCPServerUserToken in place.
|
||||
func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) error {
|
||||
if err := db.decryptField(&tok.AccessToken, tok.AccessTokenKeyID); err != nil {
|
||||
return err
|
||||
}
|
||||
return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
cfg, err := db.Store.GetMCPServerConfigByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
cfg, err := db.Store.GetMCPServerConfigBySlug(ctx, slug)
|
||||
if err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
cfgs, err := db.Store.GetMCPServerConfigs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range cfgs {
|
||||
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
cfgs, err := db.Store.GetMCPServerConfigsByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range cfgs {
|
||||
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
cfgs, err := db.Store.GetEnabledMCPServerConfigs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range cfgs {
|
||||
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
cfgs, err := db.Store.GetForcedMCPServerConfigs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range cfgs {
|
||||
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cfgs, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
tok, err := db.Store.GetMCPServerUserToken(ctx, arg)
|
||||
if err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerUserToken(&tok); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
toks, err := db.Store.GetMCPServerUserTokensByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range toks {
|
||||
if err := db.decryptMCPServerUserToken(&toks[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return toks, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertMCPServerConfig(ctx context.Context, params database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
||||
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.APIKeyValue) == "" {
|
||||
params.APIKeyValueKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.CustomHeaders) == "" {
|
||||
params.CustomHeadersKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
|
||||
cfg, err := db.Store.InsertMCPServerConfig(ctx, params)
|
||||
if err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateMCPServerConfig(ctx context.Context, params database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
||||
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.APIKeyValue) == "" {
|
||||
params.APIKeyValueKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.CustomHeaders) == "" {
|
||||
params.CustomHeadersKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
|
||||
cfg, err := db.Store.UpdateMCPServerConfig(ctx, params)
|
||||
if err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
if strings.TrimSpace(params.AccessToken) == "" {
|
||||
params.AccessTokenKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.AccessToken, ¶ms.AccessTokenKeyID); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.RefreshToken) == "" {
|
||||
params.RefreshTokenKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.RefreshToken, ¶ms.RefreshTokenKeyID); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
|
||||
tok, err := db.Store.UpsertMCPServerUserToken(ctx, params)
|
||||
if err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerUserToken(&tok); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
|
||||
// If no cipher is loaded, then we can't encrypt anything!
|
||||
if db.ciphers == nil || db.primaryCipherDigest == "" {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
@@ -878,3 +879,301 @@ func fakeBase64RandomData(t *testing.T, n int) string {
|
||||
require.NoError(t, err)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// requireMCPServerConfigDecrypted verifies all encrypted fields on an
|
||||
// MCPServerConfig match the expected plaintext values and carry the
|
||||
// correct key-ID.
|
||||
func requireMCPServerConfigDecrypted(
|
||||
t *testing.T,
|
||||
cfg database.MCPServerConfig,
|
||||
ciphers []Cipher,
|
||||
wantSecret, wantAPIKey, wantHeaders string,
|
||||
) {
|
||||
t.Helper()
|
||||
require.Equal(t, wantSecret, cfg.OAuth2ClientSecret)
|
||||
require.Equal(t, wantAPIKey, cfg.APIKeyValue)
|
||||
require.Equal(t, wantHeaders, cfg.CustomHeaders)
|
||||
require.Equal(t, ciphers[0].HexDigest(), cfg.OAuth2ClientSecretKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), cfg.APIKeyValueKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), cfg.CustomHeadersKeyID.String)
|
||||
}
|
||||
|
||||
// requireMCPServerConfigRawEncrypted reads the config from the raw
|
||||
// (unwrapped) store and asserts every secret field is encrypted.
|
||||
func requireMCPServerConfigRawEncrypted(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
rawDB database.Store,
|
||||
cfgID uuid.UUID,
|
||||
ciphers []Cipher,
|
||||
wantSecret, wantAPIKey, wantHeaders string,
|
||||
) {
|
||||
t.Helper()
|
||||
raw, err := rawDB.GetMCPServerConfigByID(ctx, cfgID)
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], raw.OAuth2ClientSecret, wantSecret)
|
||||
requireEncryptedEquals(t, ciphers[0], raw.APIKeyValue, wantAPIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], raw.CustomHeaders, wantHeaders)
|
||||
}
|
||||
|
||||
func TestMCPServerConfigs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
//nolint:gosec // test credentials
|
||||
oauthSecret = "my-oauth-secret"
|
||||
apiKeyValue = "my-api-key"
|
||||
customHeaders = `{"X-Custom":"header-value"}`
|
||||
)
|
||||
// insertConfig is a small helper that creates a user and an MCP
|
||||
// server config through the encrypted store, returning both.
|
||||
insertConfig := func(t *testing.T, crypt *dbCrypt, ciphers []Cipher) database.MCPServerConfig {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
cfg, err := crypt.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
|
||||
DisplayName: "Test MCP Server",
|
||||
Slug: "test-mcp-" + uuid.New().String()[:8],
|
||||
Description: "test description",
|
||||
Url: "https://mcp.example.com",
|
||||
Transport: "streamable_http",
|
||||
AuthType: "oauth2",
|
||||
OAuth2ClientID: "client-id",
|
||||
OAuth2ClientSecret: oauthSecret,
|
||||
APIKeyValue: apiKeyValue,
|
||||
CustomHeaders: customHeaders,
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
Availability: "force_on",
|
||||
Enabled: true,
|
||||
CreatedBy: user.ID,
|
||||
UpdatedBy: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireMCPServerConfigDecrypted(t, cfg, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
return cfg
|
||||
}
|
||||
|
||||
t.Run("InsertMCPServerConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("GetMCPServerConfigByID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
got, err := crypt.GetMCPServerConfigByID(ctx, cfg.ID)
|
||||
require.NoError(t, err)
|
||||
requireMCPServerConfigDecrypted(t, got, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("GetMCPServerConfigBySlug", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
got, err := crypt.GetMCPServerConfigBySlug(ctx, cfg.Slug)
|
||||
require.NoError(t, err)
|
||||
requireMCPServerConfigDecrypted(t, got, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("GetMCPServerConfigs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
cfgs, err := crypt.GetMCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfgs, 1)
|
||||
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("GetMCPServerConfigsByIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
cfgs, err := crypt.GetMCPServerConfigsByIDs(ctx, []uuid.UUID{cfg.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfgs, 1)
|
||||
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("GetEnabledMCPServerConfigs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
cfgs, err := crypt.GetEnabledMCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfgs, 1)
|
||||
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("GetForcedMCPServerConfigs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
cfgs, err := crypt.GetForcedMCPServerConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cfgs, 1)
|
||||
requireMCPServerConfigDecrypted(t, cfgs[0], ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, oauthSecret, apiKeyValue, customHeaders)
|
||||
})
|
||||
|
||||
t.Run("UpdateMCPServerConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg := insertConfig(t, crypt, ciphers)
|
||||
|
||||
const (
|
||||
//nolint:gosec // test credential
|
||||
newSecret = "updated-oauth-secret"
|
||||
newAPIKey = "updated-api-key"
|
||||
newHeaders = `{"X-New":"new-value"}`
|
||||
)
|
||||
updated, err := crypt.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{
|
||||
ID: cfg.ID,
|
||||
DisplayName: cfg.DisplayName,
|
||||
Slug: cfg.Slug,
|
||||
Description: cfg.Description,
|
||||
Url: cfg.Url,
|
||||
Transport: cfg.Transport,
|
||||
AuthType: cfg.AuthType,
|
||||
OAuth2ClientID: cfg.OAuth2ClientID,
|
||||
OAuth2ClientSecret: newSecret,
|
||||
APIKeyValue: newAPIKey,
|
||||
CustomHeaders: newHeaders,
|
||||
ToolAllowList: cfg.ToolAllowList,
|
||||
ToolDenyList: cfg.ToolDenyList,
|
||||
Availability: cfg.Availability,
|
||||
Enabled: cfg.Enabled,
|
||||
UpdatedBy: cfg.CreatedBy.UUID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireMCPServerConfigDecrypted(t, updated, ciphers, newSecret, newAPIKey, newHeaders)
|
||||
requireMCPServerConfigRawEncrypted(ctx, t, db, cfg.ID, ciphers, newSecret, newAPIKey, newHeaders)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMCPServerUserTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
accessToken = "access-token-value"
|
||||
refreshToken = "refresh-token-value"
|
||||
)
|
||||
|
||||
// insertConfigAndToken creates a user, an MCP server config, and a
|
||||
// user token through the encrypted store.
|
||||
insertConfigAndToken := func(
|
||||
t *testing.T,
|
||||
crypt *dbCrypt,
|
||||
ciphers []Cipher,
|
||||
) (database.MCPServerConfig, database.MCPServerUserToken) {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
cfg, err := crypt.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
|
||||
DisplayName: "Token Test MCP",
|
||||
Slug: "tok-mcp-" + uuid.New().String()[:8],
|
||||
Url: "https://mcp.example.com",
|
||||
Transport: "streamable_http",
|
||||
AuthType: "oauth2",
|
||||
ToolAllowList: []string{},
|
||||
ToolDenyList: []string{},
|
||||
Availability: "default_off",
|
||||
Enabled: true,
|
||||
CreatedBy: user.ID,
|
||||
UpdatedBy: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := crypt.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: cfg.ID,
|
||||
UserID: user.ID,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, accessToken, tok.AccessToken)
|
||||
require.Equal(t, refreshToken, tok.RefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), tok.AccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), tok.RefreshTokenKeyID.String)
|
||||
return cfg, tok
|
||||
}
|
||||
|
||||
t.Run("UpsertMCPServerUserToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg, tok := insertConfigAndToken(t, crypt, ciphers)
|
||||
|
||||
// Verify the raw DB values are encrypted.
|
||||
rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
|
||||
MCPServerConfigID: cfg.ID,
|
||||
UserID: tok.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken)
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
|
||||
})
|
||||
|
||||
t.Run("GetMCPServerUserToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg, tok := insertConfigAndToken(t, crypt, ciphers)
|
||||
|
||||
got, err := crypt.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
|
||||
MCPServerConfigID: cfg.ID,
|
||||
UserID: tok.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, accessToken, got.AccessToken)
|
||||
require.Equal(t, refreshToken, got.RefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.AccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.RefreshTokenKeyID.String)
|
||||
|
||||
// Raw values must be encrypted.
|
||||
rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
|
||||
MCPServerConfigID: cfg.ID,
|
||||
UserID: tok.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken)
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
|
||||
})
|
||||
|
||||
t.Run("GetMCPServerUserTokensByUserID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
cfg, tok := insertConfigAndToken(t, crypt, ciphers)
|
||||
|
||||
toks, err := crypt.GetMCPServerUserTokensByUserID(ctx, tok.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, toks, 1)
|
||||
require.Equal(t, accessToken, toks[0].AccessToken)
|
||||
require.Equal(t, refreshToken, toks[0].RefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), toks[0].AccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), toks[0].RefreshTokenKeyID.String)
|
||||
|
||||
// Raw values must be encrypted.
|
||||
rawTok, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
|
||||
MCPServerConfigID: cfg.ID,
|
||||
UserID: tok.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.AccessToken, accessToken)
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -73,6 +73,7 @@ const makeChat = (
|
||||
id,
|
||||
owner_id: "owner-1",
|
||||
last_model_config_id: "model-1",
|
||||
mcp_server_ids: [],
|
||||
title: `Chat ${id}`,
|
||||
status: "running",
|
||||
created_at: "2025-01-01T00:00:00.000Z",
|
||||
|
||||
Generated
+100
@@ -1070,6 +1070,7 @@ export interface Chat {
|
||||
readonly created_at: string;
|
||||
readonly updated_at: string;
|
||||
readonly archived: boolean;
|
||||
readonly mcp_server_ids: readonly string[];
|
||||
}
|
||||
|
||||
// From codersdk/deployment.go
|
||||
@@ -2076,6 +2077,7 @@ export interface ConvertLoginRequest {
|
||||
export interface CreateChatMessageRequest {
|
||||
readonly content: readonly ChatInputPart[];
|
||||
readonly model_config_id?: string;
|
||||
readonly mcp_server_ids?: string[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -2123,6 +2125,7 @@ export interface CreateChatRequest {
|
||||
readonly content: readonly ChatInputPart[];
|
||||
readonly workspace_id?: string;
|
||||
readonly model_config_id?: string;
|
||||
readonly mcp_server_ids?: readonly string[];
|
||||
}
|
||||
|
||||
// From codersdk/users.go
|
||||
@@ -2163,6 +2166,32 @@ export interface CreateGroupRequest {
|
||||
readonly quota_allowance: number;
|
||||
}
|
||||
|
||||
// From codersdk/mcp.go
|
||||
/**
|
||||
* CreateMCPServerConfigRequest is the request to create a new MCP server config.
|
||||
*/
|
||||
export interface CreateMCPServerConfigRequest {
|
||||
readonly display_name: string;
|
||||
readonly slug: string;
|
||||
readonly description: string;
|
||||
readonly icon_url: string;
|
||||
readonly transport: string;
|
||||
readonly url: string;
|
||||
readonly auth_type: string;
|
||||
readonly oauth2_client_id?: string;
|
||||
readonly oauth2_client_secret?: string;
|
||||
readonly oauth2_auth_url?: string;
|
||||
readonly oauth2_token_url?: string;
|
||||
readonly oauth2_scopes?: string;
|
||||
readonly api_key_header?: string;
|
||||
readonly api_key_value?: string;
|
||||
readonly custom_headers?: Record<string, string>;
|
||||
readonly tool_allow_list?: readonly string[];
|
||||
readonly tool_deny_list?: readonly string[];
|
||||
readonly availability: string;
|
||||
readonly enabled: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/organizations.go
|
||||
export interface CreateOrganizationRequest {
|
||||
readonly name: string;
|
||||
@@ -3692,6 +3721,51 @@ export interface LoginWithPasswordResponse {
|
||||
readonly session_token: string;
|
||||
}
|
||||
|
||||
// From codersdk/mcp.go
|
||||
/**
|
||||
* MCPServerConfig represents an admin-configured MCP server.
|
||||
*/
|
||||
export interface MCPServerConfig {
|
||||
readonly id: string;
|
||||
readonly display_name: string;
|
||||
readonly slug: string;
|
||||
readonly description: string;
|
||||
readonly icon_url: string;
|
||||
readonly transport: string; // "streamable_http" or "sse"
|
||||
readonly url: string;
|
||||
readonly auth_type: string; // "none", "oauth2", "api_key", "custom_headers"
|
||||
/**
|
||||
* OAuth2 fields (only populated for admins).
|
||||
*/
|
||||
readonly oauth2_client_id?: string;
|
||||
readonly has_oauth2_secret: boolean;
|
||||
readonly oauth2_auth_url?: string;
|
||||
readonly oauth2_token_url?: string;
|
||||
readonly oauth2_scopes?: string;
|
||||
/**
|
||||
* API key fields (only populated for admins).
|
||||
*/
|
||||
readonly api_key_header?: string;
|
||||
readonly has_api_key: boolean;
|
||||
readonly has_custom_headers: boolean;
|
||||
/**
|
||||
* Tool governance.
|
||||
*/
|
||||
readonly tool_allow_list: readonly string[];
|
||||
readonly tool_deny_list: readonly string[];
|
||||
/**
|
||||
* Availability policy set by admin.
|
||||
*/
|
||||
readonly availability: string; // "force_on", "default_on", "default_off"
|
||||
readonly enabled: boolean;
|
||||
readonly created_at: string;
|
||||
readonly updated_at: string;
|
||||
/**
|
||||
* Per-user state (populated for non-admin requests).
|
||||
*/
|
||||
readonly auth_connected: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/provisionerdaemons.go
|
||||
/**
|
||||
* MatchedProvisioners represents the number of provisioner daemons
|
||||
@@ -6844,6 +6918,32 @@ export interface UpdateInboxNotificationReadStatusResponse {
|
||||
readonly unread_count: number;
|
||||
}
|
||||
|
||||
// From codersdk/mcp.go
|
||||
/**
|
||||
* UpdateMCPServerConfigRequest is the request to update an MCP server config.
|
||||
*/
|
||||
export interface UpdateMCPServerConfigRequest {
|
||||
readonly display_name?: string;
|
||||
readonly slug?: string;
|
||||
readonly description?: string;
|
||||
readonly icon_url?: string;
|
||||
readonly transport?: string;
|
||||
readonly url?: string;
|
||||
readonly auth_type?: string;
|
||||
readonly oauth2_client_id?: string;
|
||||
readonly oauth2_client_secret?: string;
|
||||
readonly oauth2_auth_url?: string;
|
||||
readonly oauth2_token_url?: string;
|
||||
readonly oauth2_scopes?: string;
|
||||
readonly api_key_header?: string;
|
||||
readonly api_key_value?: string;
|
||||
readonly custom_headers?: Record<string, string>;
|
||||
readonly tool_allow_list?: string[];
|
||||
readonly tool_deny_list?: string[];
|
||||
readonly availability?: string;
|
||||
readonly enabled?: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/notifications.go
|
||||
export interface UpdateNotificationTemplateMethod {
|
||||
readonly method?: string;
|
||||
|
||||
@@ -106,6 +106,7 @@ const baseChatFields = {
|
||||
owner_id: "owner-id",
|
||||
workspace_id: mockWorkspace.id,
|
||||
last_model_config_id: "model-config-1",
|
||||
mcp_server_ids: [],
|
||||
created_at: "2026-02-18T00:00:00.000Z",
|
||||
updated_at: "2026-02-18T00:00:00.000Z",
|
||||
archived: false,
|
||||
|
||||
@@ -183,6 +183,7 @@ const makeChat = (chatID: string): TypesGen.Chat => ({
|
||||
id: chatID,
|
||||
owner_id: "owner-1",
|
||||
last_model_config_id: "model-1",
|
||||
mcp_server_ids: [],
|
||||
title: "test",
|
||||
status: "running",
|
||||
created_at: "2025-01-01T00:00:00.000Z",
|
||||
|
||||
@@ -52,6 +52,7 @@ export const WithParentChat: Story = {
|
||||
id: "parent-chat-1",
|
||||
owner_id: "owner-id",
|
||||
last_model_config_id: "model-config-1",
|
||||
mcp_server_ids: [],
|
||||
title: "Set up CI/CD pipeline",
|
||||
status: "completed",
|
||||
last_error: null,
|
||||
|
||||
@@ -36,6 +36,7 @@ const buildChat = (overrides: Partial<TypesGen.Chat> = {}): TypesGen.Chat => ({
|
||||
title: "Help me refactor",
|
||||
status: "completed",
|
||||
last_model_config_id: "model-config-1",
|
||||
mcp_server_ids: [],
|
||||
created_at: oneWeekAgo,
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
|
||||
@@ -113,6 +113,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
|
||||
title: "Agent",
|
||||
status: "completed",
|
||||
last_model_config_id: defaultModelConfigs[0].id,
|
||||
mcp_server_ids: [],
|
||||
created_at: oneWeekAgo,
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
|
||||
@@ -40,6 +40,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
|
||||
title: "Agent",
|
||||
status: "completed",
|
||||
last_model_config_id: defaultModelConfigs[0].id,
|
||||
mcp_server_ids: [],
|
||||
created_at: oneWeekAgo,
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
|
||||
@@ -63,6 +63,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
last_error: null,
|
||||
mcp_server_ids: [],
|
||||
...overrides,
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user