mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
d8ff67fb68
## Summary
Adds the database schema, API endpoints, SDK types, and encryption
wrappers for admin-managed MCP (Model Context Protocol) server
configurations that chatd can consume. This is the backend foundation
for allowing external MCP tools (Sentry, Linear, GitHub, etc.) to be
used during AI chat sessions.
## Database
Two new tables:
- **`mcp_server_configs`**: Admin-managed server definitions with URL,
transport (Streamable HTTP / SSE), auth config (none / OAuth2 / API key
/ custom headers), tool allow/deny lists, and an availability policy
(`force_on` / `default_on` / `default_off`). Includes CHECK constraints
on transport, auth_type, and availability values.
- **`mcp_server_user_tokens`**: Per-user OAuth2 tokens for servers
requiring individual authentication. Cascades on user/config deletion.
New column on `chats` table:
- **`mcp_server_ids UUID[]`**: Per-chat MCP server selection, following
the same pattern as `model_config_id` — passed at chat creation,
changeable per-message with nil-means-no-change semantics.
## API Endpoints
All routes are under `/api/experimental/mcp/servers/` and gated behind
the `agents` experiment.
**Admin endpoints** (`ResourceDeploymentConfig` auth):
- `POST /` — Create MCP server config
- `PATCH /{id}` — Update MCP server config (full-replace)
- `DELETE /{id}` — Delete MCP server config
**Authenticated endpoints** (all users, enabled servers only for
non-admins):
- `GET /` — List configs (admins see all, members see enabled-only with
admin fields redacted)
- `GET /{id}` — Get config by ID (with `auth_connected` populated
per-user)
**OAuth2 per-user auth flow:**
- `GET /{id}/oauth2/connect` — Initiate OAuth2 flow (state cookie CSRF
protection)
- `GET /{id}/oauth2/callback` — Handle OAuth2 callback, store tokens
- `DELETE /{id}/oauth2/disconnect` — Remove stored OAuth2 tokens
## Security
- **Secrets never returned**: `OAuth2ClientSecret`, `APIKeyValue`, and
`CustomHeaders` are never in API responses — only boolean indicators
(`has_oauth2_secret`, `has_api_key`, `has_custom_headers`).
- **Field redaction for non-admins**: `convertMCPServerConfigRedacted`
strips `OAuth2ClientID`, auth URLs, scopes, and `APIKeyHeader` from
non-admin responses.
- **dbcrypt encryption at rest**: All 5 secret fields use `dbcrypt_keys`
encryption with full encrypt-on-write / decrypt-on-read wrappers (11
dbcrypt method overrides + 2 helpers), following the same pattern as
`chat_providers.api_key`.
- **OAuth2 CSRF protection**: State parameter stored in `HttpOnly`
cookie with `HTTPCookies.Apply()` for correct `Secure`/`SameSite` behind
TLS-terminating proxies.
- **dbauthz authorization**: All 18 querier methods have authorization
wrappers. Read operations use `ActionRead`, write operations use
`ActionUpdate` on `ResourceDeploymentConfig`.
## Governance Model
| Control | Implementation |
|---------|---------------|
| **Global kill switch** | `enabled` defaults to `false` |
| **Availability policy** | `force_on` (always injected), `default_on`
(pre-selected), `default_off` (opt-in) |
| **Per-chat selection** | `mcp_server_ids` on `CreateChatRequest` /
`CreateChatMessageRequest` |
| **Auth gate** | OAuth2 servers require per-user auth before tools are
injected |
| **Tool-level allow/deny** | Arrays on `mcp_server_configs` for
granular tool filtering |
| **Secrets encrypted at rest** | Uses `dbcrypt_keys` (same pattern as
`chat_providers.api_key`) |
## Tests
8 test functions covering:
- Full CRUD lifecycle (create, list, update, delete)
- Non-admin visibility filtering (enabled-only, field redaction)
- `auth_connected` population for OAuth2 vs non-OAuth2 servers
- Availability policy validation (valid values + invalid rejection)
- Unique slug enforcement (409 Conflict)
- OAuth2 disconnect idempotency
- Chat creation with `mcp_server_ids` persistence
## Known Limitations (Deferred)
These are documented and intentional for an experimental feature:
- **Audit logging** not yet wired — will add when feature stabilizes
- **Cross-field validation** (e.g., OAuth2 fields required when
`auth_type=oauth2`) — admin-only endpoint, will add when stabilizing
- **`force_on` auto-injection** — query exists but not yet wired into
chatd tool injection (follow-up)
- **Additional test coverage** — 403 auth tests, GET-by-ID tests,
callback CSRF tests planned for follow-up
## What's NOT in this PR
- Frontend UI (admin panel + chat picker)
- Actual MCP client connections (`chatd/chatmcp/` manager)
- Tool injection into `chatloop/`
780 lines
26 KiB
Go
780 lines
26 KiB
Go
package dbcrypt
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
)
|
|
|
|
// testValue is the plaintext written to the test column of dbcrypt_keys.
// Being able to decrypt that column back shows that a key is valid.
const (
	testValue = "coder"
)
|
|
|
|
// b64encode renders raw bytes as standard (padded) base64 text.
var b64encode = base64.StdEncoding.EncodeToString

// b64decode parses standard (padded) base64 text back into raw bytes.
var b64decode = base64.StdEncoding.DecodeString
|
|
|
|
// DecryptFailedError is returned when decryption fails.
|
|
type DecryptFailedError struct {
|
|
Inner error
|
|
}
|
|
|
|
func (e *DecryptFailedError) Error() string {
|
|
return xerrors.Errorf("decrypt failed: %w", e.Inner).Error()
|
|
}
|
|
|
|
// New creates a database.Store wrapper that encrypts/decrypts values
|
|
// stored at rest in the database.
|
|
func New(ctx context.Context, db database.Store, ciphers ...Cipher) (database.Store, error) {
|
|
cm := make(map[string]Cipher)
|
|
for _, c := range ciphers {
|
|
cm[c.HexDigest()] = c
|
|
}
|
|
dbc := &dbCrypt{
|
|
ciphers: cm,
|
|
Store: db,
|
|
}
|
|
if len(ciphers) > 0 {
|
|
dbc.primaryCipherDigest = ciphers[0].HexDigest()
|
|
}
|
|
// nolint: gocritic // This is allowed.
|
|
authCtx := dbauthz.AsSystemRestricted(ctx)
|
|
if err := dbc.ensureEncryptedWithRetry(authCtx); err != nil {
|
|
return nil, xerrors.Errorf("ensure encrypted database fields: %w", err)
|
|
}
|
|
return dbc, nil
|
|
}
|
|
|
|
type dbCrypt struct {
|
|
// primaryCipherDigest is the digest of the primary cipher used for encrypting data.
|
|
primaryCipherDigest string
|
|
// ciphers is a map of cipher digests to ciphers.
|
|
ciphers map[string]Cipher
|
|
database.Store
|
|
}
|
|
|
|
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *database.TxOptions) error {
|
|
return db.Store.InTx(func(s database.Store) error {
|
|
return function(&dbCrypt{
|
|
primaryCipherDigest: db.primaryCipherDigest,
|
|
ciphers: db.ciphers,
|
|
Store: s,
|
|
})
|
|
}, txOpts)
|
|
}
|
|
|
|
func (db *dbCrypt) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
|
ks, err := db.Store.GetDBCryptKeys(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Decrypt the test field to ensure that the key is valid.
|
|
for i := range ks {
|
|
if !ks[i].ActiveKeyDigest.Valid {
|
|
// Key has been revoked. We can't decrypt the test field, but
|
|
// we need to return it so that the caller knows that the key
|
|
// has been revoked.
|
|
continue
|
|
}
|
|
if err := db.decryptField(&ks[i].Test, ks[i].ActiveKeyDigest); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return ks, nil
|
|
}
|
|
|
|
// This does not need any special handling as it does not touch any encrypted fields.
|
|
// Explicitly defining this here to avoid confusion.
|
|
func (db *dbCrypt) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
|
return db.Store.RevokeDBCryptKey(ctx, activeKeyDigest)
|
|
}
|
|
|
|
func (db *dbCrypt) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
|
|
// It's nicer to be able to pass a *sql.NullString to encryptField, but we need to pass a string here.
|
|
var digest sql.NullString
|
|
err := db.encryptField(&arg.Test, &digest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
arg.ActiveKeyDigest = digest.String
|
|
return db.Store.InsertDBCryptKey(ctx, arg)
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
|
|
link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
|
|
links, err := db.Store.GetUserLinksByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for idx := range links {
|
|
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return links, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
|
|
link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
link, err := db.Store.InsertUserLink(ctx, params)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
link, err := db.Store.UpdateUserLink(ctx, params)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertExternalAuthLink(ctx context.Context, params database.InsertExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
link, err := db.Store.InsertExternalAuthLink(ctx, params)
|
|
if err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetExternalAuthLink(ctx context.Context, params database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
|
link, err := db.Store.GetExternalAuthLink(ctx, params)
|
|
if err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error) {
|
|
links, err := db.Store.GetExternalAuthLinksByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for idx := range links {
|
|
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return links, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
link, err := db.Store.UpdateExternalAuthLink(ctx, params)
|
|
if err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
|
|
// The SQL query uses an optimistic lock:
|
|
// WHERE oauth_refresh_token = @old_oauth_refresh_token
|
|
// The caller supplies the plaintext old token (since dbcrypt
|
|
// decrypts on read), but the DB stores the encrypted value.
|
|
// Because AES-GCM is non-deterministic, we cannot simply
|
|
// re-encrypt the old token — the ciphertext would differ.
|
|
// Instead, read the current row from the inner (raw) store
|
|
// and use the actual encrypted value for the WHERE clause.
|
|
if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" {
|
|
raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
|
ProviderID: params.ProviderID,
|
|
UserID: params.UserID,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Decrypt the stored token so we can compare with the
|
|
// caller-supplied plaintext.
|
|
decrypted := raw.OAuthRefreshToken
|
|
if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil {
|
|
return err
|
|
}
|
|
if decrypted != params.OldOauthRefreshToken {
|
|
// The token has changed since the caller read it;
|
|
// the optimistic lock should fail (no rows updated).
|
|
// Return nil to match the :exec semantics of the SQL
|
|
// query, which silently updates zero rows.
|
|
return nil
|
|
}
|
|
// Use the raw encrypted value so the WHERE clause matches.
|
|
params.OldOauthRefreshToken = raw.OAuthRefreshToken
|
|
}
|
|
|
|
// We would normally use a sql.NullString here, but sqlc does not want to make
|
|
// a params struct with a nullable string.
|
|
var digest sql.NullString
|
|
if params.OAuthRefreshTokenKeyID != "" {
|
|
digest.String = params.OAuthRefreshTokenKeyID
|
|
digest.Valid = true
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, &digest); err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params)
|
|
}
|
|
|
|
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
|
|
keys, err := db.Store.GetCryptoKeys(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range keys {
|
|
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) {
|
|
key, err := db.Store.GetLatestCryptoKeyByFeature(ctx, feature)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
|
|
key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
|
if err := db.encryptField(¶ms.Secret.String, ¶ms.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
key, err := db.Store.InsertCryptoKey(ctx, params)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
|
key, err := db.Store.UpdateCryptoKeyDeletesAt(ctx, arg)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetCryptoKeysByFeature(ctx context.Context, feature database.CryptoKeyFeature) ([]database.CryptoKey, error) {
|
|
keys, err := db.Store.GetCryptoKeysByFeature(ctx, feature)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range keys {
|
|
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
|
provider, err := db.Store.GetChatProviderByID(ctx, id)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
|
|
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
|
providers, err := db.Store.GetChatProviders(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range providers {
|
|
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return providers, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
|
providers, err := db.Store.GetEnabledChatProviders(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range providers {
|
|
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return providers, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
|
|
if strings.TrimSpace(params.APIKey) == "" {
|
|
params.ApiKeyKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
|
|
provider, err := db.Store.InsertChatProvider(ctx, params)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
|
if strings.TrimSpace(params.APIKey) == "" {
|
|
params.ApiKeyKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
|
|
provider, err := db.Store.UpdateChatProvider(ctx, params)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
// decryptMCPServerConfig decrypts all encrypted fields on a
|
|
// single MCPServerConfig in place.
|
|
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
|
|
if err := db.decryptField(&cfg.OAuth2ClientSecret, cfg.OAuth2ClientSecretKeyID); err != nil {
|
|
return err
|
|
}
|
|
if err := db.decryptField(&cfg.APIKeyValue, cfg.APIKeyValueKeyID); err != nil {
|
|
return err
|
|
}
|
|
return db.decryptField(&cfg.CustomHeaders, cfg.CustomHeadersKeyID)
|
|
}
|
|
|
|
// decryptMCPServerUserToken decrypts all encrypted fields on a
|
|
// single MCPServerUserToken in place.
|
|
func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) error {
|
|
if err := db.decryptField(&tok.AccessToken, tok.AccessTokenKeyID); err != nil {
|
|
return err
|
|
}
|
|
return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID)
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
|
cfg, err := db.Store.GetMCPServerConfigByID(ctx, id)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
|
cfg, err := db.Store.GetMCPServerConfigBySlug(ctx, slug)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetMCPServerConfigs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetMCPServerConfigsByIDs(ctx, ids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetEnabledMCPServerConfigs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetForcedMCPServerConfigs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
|
tok, err := db.Store.GetMCPServerUserToken(ctx, arg)
|
|
if err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
if err := db.decryptMCPServerUserToken(&tok); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
return tok, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
|
toks, err := db.Store.GetMCPServerUserTokensByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range toks {
|
|
if err := db.decryptMCPServerUserToken(&toks[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return toks, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertMCPServerConfig(ctx context.Context, params database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
|
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
|
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.APIKeyValue) == "" {
|
|
params.APIKeyValueKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.CustomHeaders) == "" {
|
|
params.CustomHeadersKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
|
|
cfg, err := db.Store.InsertMCPServerConfig(ctx, params)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateMCPServerConfig(ctx context.Context, params database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
|
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
|
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.APIKeyValue) == "" {
|
|
params.APIKeyValueKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.CustomHeaders) == "" {
|
|
params.CustomHeadersKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
|
|
cfg, err := db.Store.UpdateMCPServerConfig(ctx, params)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
|
if strings.TrimSpace(params.AccessToken) == "" {
|
|
params.AccessTokenKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.AccessToken, ¶ms.AccessTokenKeyID); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
if strings.TrimSpace(params.RefreshToken) == "" {
|
|
params.RefreshTokenKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.RefreshToken, ¶ms.RefreshTokenKeyID); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
|
|
tok, err := db.Store.UpsertMCPServerUserToken(ctx, params)
|
|
if err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
if err := db.decryptMCPServerUserToken(&tok); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
return tok, nil
|
|
}
|
|
|
|
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
|
|
// If no cipher is loaded, then we can't encrypt anything!
|
|
if db.ciphers == nil || db.primaryCipherDigest == "" {
|
|
return nil
|
|
}
|
|
|
|
if field == nil {
|
|
return xerrors.Errorf("developer error: encryptField called with nil field")
|
|
}
|
|
if digest == nil {
|
|
return xerrors.Errorf("developer error: encryptField called with nil digest")
|
|
}
|
|
|
|
encrypted, err := db.ciphers[db.primaryCipherDigest].Encrypt([]byte(*field))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Base64 is used to support UTF-8 encoding in PostgreSQL.
|
|
*field = b64encode(encrypted)
|
|
*digest = sql.NullString{String: db.primaryCipherDigest, Valid: true}
|
|
return nil
|
|
}
|
|
|
|
// decryptFields decrypts the given field using the key with the given digest.
|
|
// If the value fails to decrypt, sql.ErrNoRows will be returned.
|
|
func (db *dbCrypt) decryptField(field *string, digest sql.NullString) error {
|
|
if field == nil {
|
|
return xerrors.Errorf("developer error: decryptField called with nil field")
|
|
}
|
|
|
|
if !digest.Valid || digest.String == "" {
|
|
// This field is not encrypted.
|
|
return nil
|
|
}
|
|
|
|
key, ok := db.ciphers[digest.String]
|
|
if !ok {
|
|
return &DecryptFailedError{
|
|
Inner: xerrors.Errorf("no cipher with digest %q", digest.String),
|
|
}
|
|
}
|
|
|
|
data, err := b64decode(*field)
|
|
if err != nil {
|
|
// If it's not valid base64, we should complain loudly.
|
|
return &DecryptFailedError{
|
|
Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err),
|
|
}
|
|
}
|
|
decrypted, err := key.Decrypt(data)
|
|
if err != nil {
|
|
return &DecryptFailedError{Inner: err}
|
|
}
|
|
*field = string(decrypted)
|
|
return nil
|
|
}
|
|
|
|
func (db *dbCrypt) ensureEncryptedWithRetry(ctx context.Context) error {
|
|
var err error
|
|
for i := 0; i < 3; i++ {
|
|
err = db.ensureEncrypted(ctx)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
// If we get a serialization error, then we need to retry.
|
|
if !database.IsSerializedError(err) {
|
|
return err
|
|
}
|
|
// otherwise, retry
|
|
}
|
|
// If we get here, then we ran out of retries
|
|
return err
|
|
}
|
|
|
|
func (db *dbCrypt) ensureEncrypted(ctx context.Context) error {
|
|
return db.InTx(func(s database.Store) error {
|
|
// Attempt to read the encrypted test fields of the currently active keys.
|
|
ks, err := s.GetDBCryptKeys(ctx)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
|
|
var highestNumber int32
|
|
var activeCipherFound bool
|
|
for _, k := range ks {
|
|
// If our primary key has been revoked, then we can't do anything.
|
|
if k.RevokedKeyDigest.Valid && k.RevokedKeyDigest.String == db.primaryCipherDigest {
|
|
return xerrors.Errorf("primary encryption key %q has been revoked", db.primaryCipherDigest)
|
|
}
|
|
|
|
if k.ActiveKeyDigest.Valid && k.ActiveKeyDigest.String == db.primaryCipherDigest {
|
|
activeCipherFound = true
|
|
}
|
|
|
|
if k.Number > highestNumber {
|
|
highestNumber = k.Number
|
|
}
|
|
}
|
|
|
|
if activeCipherFound || len(db.ciphers) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// If we get here, then we have a new key that we need to insert.
|
|
return db.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
|
|
Number: highestNumber + 1,
|
|
ActiveKeyDigest: db.primaryCipherDigest,
|
|
Test: testValue,
|
|
})
|
|
}, &database.TxOptions{Isolation: sql.LevelRepeatableRead})
|
|
}
|