mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(enterprise/dbcrypt): rotate decrypt and delete MCP server secrets
The dbcrypt CLI iterated user_links, external_auth_links, user_secrets, ai_providers, ai_provider_keys, and user_ai_provider_keys, but not the three MCP tables that the interceptor encrypts: mcp_server_configs, mcp_server_user_tokens, and mcp_server_user_header_values. After `server dbcrypt rotate` or `decrypt` revoked the prior cipher, MCP rows still referenced the revoked digest and became unreadable. Add per-row updates that re-encrypt or clear the MCP secrets before the old keys are revoked, and extend the destructive `delete` SQL to clear the three tables. A new UpdateEncryptedMCPServerConfig query plus its dbcrypt wrapper move the three mcp_server_configs columns (and their key_id pointers) atomically. Extend TestServerDBCrypt to seed an MCP server config plus the per-user token and header rows for each user, and assert that every rotate / decrypt / delete step touches them correctly.
This commit is contained in:
@@ -6956,6 +6956,17 @@ func (q *querier) UpdateEncryptedAIProviderSettings(ctx context.Context, arg dat
|
||||
return q.db.UpdateEncryptedAIProviderSettings(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateEncryptedMCPServerConfig(ctx context.Context, arg database.UpdateEncryptedMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
// Updates only the encrypted columns so the dbcrypt rotation
|
||||
// utility can move every secret to a new cipher digest before old
|
||||
// keys are revoked. Treated as an admin-only operation just like
|
||||
// the regular MCP server config update path.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.UpdateEncryptedMCPServerConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) {
|
||||
// Encrypted user-owned provider keys can be rewritten on any row so
|
||||
// dbcrypt rotation can move every key to a new digest. This is a
|
||||
|
||||
@@ -1702,6 +1702,17 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().DeleteMCPServerUserHeaderValuesByConfigID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateEncryptedMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
cfg := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
arg := database.UpdateEncryptedMCPServerConfigParams{
|
||||
ID: cfg.ID,
|
||||
OAuth2ClientSecret: "encrypted-secret",
|
||||
APIKeyValue: "encrypted-api-key",
|
||||
CustomHeaders: `{"X-Foo":"encrypted"}`,
|
||||
}
|
||||
dbm.EXPECT().UpdateEncryptedMCPServerConfig(gomock.Any(), arg).Return(cfg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(cfg)
|
||||
}))
|
||||
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertMCPServerConfigParams{
|
||||
DisplayName: "Test MCP Server",
|
||||
|
||||
+8
@@ -4977,6 +4977,14 @@ func (m queryMetricsStore) UpdateEncryptedAIProviderSettings(ctx context.Context
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateEncryptedMCPServerConfig(ctx context.Context, arg database.UpdateEncryptedMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateEncryptedMCPServerConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateEncryptedMCPServerConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateEncryptedMCPServerConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateEncryptedUserAIProviderKey(ctx, arg)
|
||||
|
||||
Generated
+15
@@ -9413,6 +9413,21 @@ func (mr *MockStoreMockRecorder) UpdateEncryptedAIProviderSettings(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEncryptedAIProviderSettings", reflect.TypeOf((*MockStore)(nil).UpdateEncryptedAIProviderSettings), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateEncryptedMCPServerConfig mocks base method.
|
||||
func (m *MockStore) UpdateEncryptedMCPServerConfig(ctx context.Context, arg database.UpdateEncryptedMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateEncryptedMCPServerConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateEncryptedMCPServerConfig indicates an expected call of UpdateEncryptedMCPServerConfig.
|
||||
func (mr *MockStoreMockRecorder) UpdateEncryptedMCPServerConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEncryptedMCPServerConfig", reflect.TypeOf((*MockStore)(nil).UpdateEncryptedMCPServerConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateEncryptedUserAIProviderKey mocks base method.
|
||||
func (m *MockStore) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+6
@@ -1246,6 +1246,12 @@ type sqlcQuerier interface {
|
||||
// Used by the dbcrypt key rotation utility to re-encrypt or decrypt
|
||||
// rows in place.
|
||||
UpdateEncryptedAIProviderSettings(ctx context.Context, arg UpdateEncryptedAIProviderSettingsParams) (AIProvider, error)
|
||||
// Updates only the encrypted columns (oauth2_client_secret,
|
||||
// api_key_value, custom_headers) and their per-row key_id pointers,
|
||||
// plus the updated_at timestamp. Used by the dbcrypt key rotation
|
||||
// utility to re-encrypt or decrypt rows in place so old ciphers can
|
||||
// be revoked without orphaning MCP secrets.
|
||||
UpdateEncryptedMCPServerConfig(ctx context.Context, arg UpdateEncryptedMCPServerConfigParams) (MCPServerConfig, error)
|
||||
UpdateEncryptedUserAIProviderKey(ctx context.Context, arg UpdateEncryptedUserAIProviderKeyParams) (UserAiProviderKey, error)
|
||||
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
// Optimistic lock: only update the row if the refresh token in the database
|
||||
|
||||
Generated
+80
@@ -15772,6 +15772,86 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateEncryptedMCPServerConfig = `-- name: UpdateEncryptedMCPServerConfig :one
|
||||
UPDATE
|
||||
mcp_server_configs
|
||||
SET
|
||||
oauth2_client_secret = $1::text,
|
||||
oauth2_client_secret_key_id = $2::text,
|
||||
api_key_value = $3::text,
|
||||
api_key_value_key_id = $4::text,
|
||||
custom_headers = $5::text,
|
||||
custom_headers_key_id = $6::text,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $7::uuid
|
||||
RETURNING
|
||||
id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions
|
||||
`
|
||||
|
||||
type UpdateEncryptedMCPServerConfigParams struct {
|
||||
OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"`
|
||||
OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"`
|
||||
APIKeyValue string `db:"api_key_value" json:"api_key_value"`
|
||||
APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"`
|
||||
CustomHeaders string `db:"custom_headers" json:"custom_headers"`
|
||||
CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
// Updates only the encrypted columns (oauth2_client_secret,
|
||||
// api_key_value, custom_headers) and their per-row key_id pointers,
|
||||
// plus the updated_at timestamp. Used by the dbcrypt key rotation
|
||||
// utility to re-encrypt or decrypt rows in place so old ciphers can
|
||||
// be revoked without orphaning MCP secrets.
|
||||
func (q *sqlQuerier) UpdateEncryptedMCPServerConfig(ctx context.Context, arg UpdateEncryptedMCPServerConfigParams) (MCPServerConfig, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateEncryptedMCPServerConfig,
|
||||
arg.OAuth2ClientSecret,
|
||||
arg.OAuth2ClientSecretKeyID,
|
||||
arg.APIKeyValue,
|
||||
arg.APIKeyValueKeyID,
|
||||
arg.CustomHeaders,
|
||||
arg.CustomHeadersKeyID,
|
||||
arg.ID,
|
||||
)
|
||||
var i MCPServerConfig
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.DisplayName,
|
||||
&i.Slug,
|
||||
&i.Description,
|
||||
&i.IconURL,
|
||||
&i.Transport,
|
||||
&i.Url,
|
||||
&i.AuthType,
|
||||
&i.OAuth2ClientID,
|
||||
&i.OAuth2ClientSecret,
|
||||
&i.OAuth2ClientSecretKeyID,
|
||||
&i.OAuth2AuthURL,
|
||||
&i.OAuth2TokenURL,
|
||||
&i.OAuth2Scopes,
|
||||
&i.APIKeyHeader,
|
||||
&i.APIKeyValue,
|
||||
&i.APIKeyValueKeyID,
|
||||
&i.CustomHeaders,
|
||||
&i.CustomHeadersKeyID,
|
||||
pq.Array(&i.ToolAllowList),
|
||||
pq.Array(&i.ToolDenyList),
|
||||
&i.Availability,
|
||||
&i.Enabled,
|
||||
&i.CreatedBy,
|
||||
&i.UpdatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ModelIntent,
|
||||
&i.AllowInPlanMode,
|
||||
&i.ForwardCoderHeaders,
|
||||
pq.Array(&i.CustomHeadersUserKeys),
|
||||
&i.CustomHeadersUserKeyDescriptions,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateMCPServerConfig = `-- name: UpdateMCPServerConfig :one
|
||||
UPDATE
|
||||
mcp_server_configs
|
||||
|
||||
@@ -162,6 +162,27 @@ DELETE FROM
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: UpdateEncryptedMCPServerConfig :one
|
||||
-- Updates only the encrypted columns (oauth2_client_secret,
|
||||
-- api_key_value, custom_headers) and their per-row key_id pointers,
|
||||
-- plus the updated_at timestamp. Used by the dbcrypt key rotation
|
||||
-- utility to re-encrypt or decrypt rows in place so old ciphers can
|
||||
-- be revoked without orphaning MCP secrets.
|
||||
UPDATE
|
||||
mcp_server_configs
|
||||
SET
|
||||
oauth2_client_secret = @oauth2_client_secret::text,
|
||||
oauth2_client_secret_key_id = sqlc.narg('oauth2_client_secret_key_id')::text,
|
||||
api_key_value = @api_key_value::text,
|
||||
api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text,
|
||||
custom_headers = @custom_headers::text,
|
||||
custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetMCPServerUserToken :one
|
||||
SELECT
|
||||
*
|
||||
|
||||
@@ -204,6 +204,23 @@ func TestServerDBCrypt(t *testing.T) {
|
||||
userSecrets, err := db.ListUserSecretsWithValues(ctx, usr.ID)
|
||||
require.NoError(t, err, "failed to get user secrets for user %s", usr.ID)
|
||||
require.Empty(t, userSecrets)
|
||||
|
||||
mcpUserTokens, err := db.GetMCPServerUserTokensByUserID(ctx, usr.ID)
|
||||
require.NoError(t, err, "failed to get mcp server user tokens for user %s", usr.ID)
|
||||
require.Empty(t, mcpUserTokens)
|
||||
|
||||
mcpHeaderRows, err := db.GetMCPServerUserHeaderValuesByUserID(ctx, usr.ID)
|
||||
require.NoError(t, err, "failed to get mcp server user header values for user %s", usr.ID)
|
||||
require.Empty(t, mcpHeaderRows)
|
||||
|
||||
mcpConfig, err := db.GetMCPServerConfigBySlug(ctx, "mcp-"+usr.ID.String())
|
||||
require.NoError(t, err, "failed to get mcp server config for user %s", usr.ID)
|
||||
require.Empty(t, mcpConfig.OAuth2ClientSecret, "mcp_server_configs.oauth2_client_secret should be cleared")
|
||||
require.False(t, mcpConfig.OAuth2ClientSecretKeyID.Valid, "mcp_server_configs.oauth2_client_secret_key_id should be NULL")
|
||||
require.Empty(t, mcpConfig.APIKeyValue, "mcp_server_configs.api_key_value should be cleared")
|
||||
require.False(t, mcpConfig.APIKeyValueKeyID.Valid, "mcp_server_configs.api_key_value_key_id should be NULL")
|
||||
require.Equal(t, "{}", mcpConfig.CustomHeaders, "mcp_server_configs.custom_headers should be reset to {}")
|
||||
require.False(t, mcpConfig.CustomHeadersKeyID.Valid, "mcp_server_configs.custom_headers_key_id should be NULL")
|
||||
}
|
||||
|
||||
// Validate that the key has been revoked in the database.
|
||||
@@ -256,6 +273,18 @@ func genData(t *testing.T, db database.Store) []database.User {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Each user owns one MCP server config so dbcrypt rotate /
|
||||
// decrypt / delete exercise mcp_server_configs alongside the
|
||||
// per-user MCP rows below.
|
||||
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
||||
Slug: "mcp-" + usr.ID.String(),
|
||||
OAuth2ClientSecret: "mcp-oauth2-secret-" + usr.ID.String(),
|
||||
APIKeyValue: "mcp-api-key-" + usr.ID.String(),
|
||||
CustomHeaders: "mcp-custom-headers-" + usr.ID.String(),
|
||||
CreatedBy: uuid.NullUUID{UUID: usr.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: usr.ID, Valid: true},
|
||||
})
|
||||
|
||||
// Deleted users cannot have user_links or user_secrets.
|
||||
if !deleted {
|
||||
// Fun fact: our schema allows _all_ login types to have
|
||||
@@ -275,6 +304,21 @@ func genData(t *testing.T, db database.Store) []database.User {
|
||||
EnvName: "",
|
||||
FilePath: "",
|
||||
})
|
||||
|
||||
_, err := db.UpsertMCPServerUserToken(context.Background(), database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: mcpConfig.ID,
|
||||
UserID: usr.ID,
|
||||
AccessToken: "mcp-access-" + usr.ID.String(),
|
||||
RefreshToken: "mcp-refresh-" + usr.ID.String(),
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = dbgen.MCPServerUserHeaderValues(t, db, database.McpServerUserHeaderValue{
|
||||
MCPServerConfigID: mcpConfig.ID,
|
||||
UserID: usr.ID,
|
||||
HeaderValues: "mcp-headers-" + usr.ID.String(),
|
||||
})
|
||||
}
|
||||
users = append(users, usr)
|
||||
}
|
||||
@@ -354,6 +398,31 @@ func requireEncryptedWithCipher(ctx context.Context, t *testing.T, db database.S
|
||||
require.Len(t, userAIProviderKeys, 1)
|
||||
requireEncryptedEquals(t, c, "user-ai-provider-key-"+userID.String(), userAIProviderKeys[0].APIKey)
|
||||
require.Equal(t, c.HexDigest(), userAIProviderKeys[0].ApiKeyKeyID.String)
|
||||
|
||||
mcpConfig, err := db.GetMCPServerConfigBySlug(ctx, "mcp-"+userID.String())
|
||||
require.NoError(t, err, "failed to get mcp server config for user %s", userID)
|
||||
requireEncryptedEquals(t, c, "mcp-oauth2-secret-"+userID.String(), mcpConfig.OAuth2ClientSecret)
|
||||
require.Equal(t, c.HexDigest(), mcpConfig.OAuth2ClientSecretKeyID.String)
|
||||
requireEncryptedEquals(t, c, "mcp-api-key-"+userID.String(), mcpConfig.APIKeyValue)
|
||||
require.Equal(t, c.HexDigest(), mcpConfig.APIKeyValueKeyID.String)
|
||||
requireEncryptedEquals(t, c, "mcp-custom-headers-"+userID.String(), mcpConfig.CustomHeaders)
|
||||
require.Equal(t, c.HexDigest(), mcpConfig.CustomHeadersKeyID.String)
|
||||
|
||||
mcpUserTokens, err := db.GetMCPServerUserTokensByUserID(ctx, userID)
|
||||
require.NoError(t, err, "failed to get mcp server user tokens for user %s", userID)
|
||||
for _, tok := range mcpUserTokens {
|
||||
requireEncryptedEquals(t, c, "mcp-access-"+userID.String(), tok.AccessToken)
|
||||
require.Equal(t, c.HexDigest(), tok.AccessTokenKeyID.String)
|
||||
requireEncryptedEquals(t, c, "mcp-refresh-"+userID.String(), tok.RefreshToken)
|
||||
require.Equal(t, c.HexDigest(), tok.RefreshTokenKeyID.String)
|
||||
}
|
||||
|
||||
mcpHeaderRows, err := db.GetMCPServerUserHeaderValuesByUserID(ctx, userID)
|
||||
require.NoError(t, err, "failed to get mcp server user header values for user %s", userID)
|
||||
for _, row := range mcpHeaderRows {
|
||||
requireEncryptedEquals(t, c, "mcp-headers-"+userID.String(), row.HeaderValues)
|
||||
require.Equal(t, c.HexDigest(), row.HeaderValuesKeyID.String)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerAIProviderKeysEncryptedWithDBCrypt starts a real enterprise server
|
||||
|
||||
@@ -101,6 +101,51 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
log.Debug(ctx, "rotated user secret", slog.F("user_id", uid), slog.F("secret_name", secret.Name), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
mcpUserTokens, err := cryptTx.GetMCPServerUserTokensByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get mcp server user tokens for user %s: %w", uid, err)
|
||||
}
|
||||
for _, tok := range mcpUserTokens {
|
||||
if tok.AccessTokenKeyID.Valid && tok.AccessTokenKeyID.String == ciphers[0].HexDigest() &&
|
||||
tok.RefreshTokenKeyID.Valid && tok.RefreshTokenKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping mcp server user token", slog.F("user_id", uid), slog.F("mcp_server_config_id", tok.MCPServerConfigID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptTx.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: tok.MCPServerConfigID,
|
||||
UserID: uid,
|
||||
AccessToken: tok.AccessToken,
|
||||
AccessTokenKeyID: sql.NullString{}, // dbcrypt will re-encrypt
|
||||
RefreshToken: tok.RefreshToken,
|
||||
RefreshTokenKeyID: sql.NullString{}, // dbcrypt will re-encrypt
|
||||
TokenType: tok.TokenType,
|
||||
Expiry: tok.Expiry,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("rotate mcp server user token user_id=%s mcp_server_config_id=%s: %w", uid, tok.MCPServerConfigID, err)
|
||||
}
|
||||
log.Debug(ctx, "rotated mcp server user token", slog.F("user_id", uid), slog.F("mcp_server_config_id", tok.MCPServerConfigID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
mcpHeaderRows, err := cryptTx.GetMCPServerUserHeaderValuesByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get mcp server user header values for user %s: %w", uid, err)
|
||||
}
|
||||
for _, row := range mcpHeaderRows {
|
||||
if row.HeaderValuesKeyID.Valid && row.HeaderValuesKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping mcp server user header values", slog.F("user_id", uid), slog.F("mcp_server_config_id", row.MCPServerConfigID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptTx.UpsertMCPServerUserHeaderValues(ctx, database.UpsertMCPServerUserHeaderValuesParams{
|
||||
MCPServerConfigID: row.MCPServerConfigID,
|
||||
UserID: uid,
|
||||
HeaderValues: row.HeaderValues,
|
||||
HeaderValuesKeyID: sql.NullString{}, // dbcrypt will re-encrypt
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("rotate mcp server user header values user_id=%s mcp_server_config_id=%s: %w", uid, row.MCPServerConfigID, err)
|
||||
}
|
||||
log.Debug(ctx, "rotated mcp server user header values", slog.F("user_id", uid), slog.F("mcp_server_config_id", row.MCPServerConfigID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}, &database.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
@@ -180,6 +225,33 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
log.Debug(ctx, "encrypted user ai provider key", slog.F("user_ai_provider_key_id", key.ID), slog.F("ai_provider_id", key.AIProviderID), slog.F("user_id", key.UserID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
mcpConfigs, err := cryptDB.GetMCPServerConfigs(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get mcp server configs: %w", err)
|
||||
}
|
||||
log.Info(ctx, "encrypting mcp server config secrets", slog.F("config_count", len(mcpConfigs)))
|
||||
for idx, cfg := range mcpConfigs {
|
||||
oauth2Current := cfg.OAuth2ClientSecretKeyID.Valid && cfg.OAuth2ClientSecretKeyID.String == ciphers[0].HexDigest()
|
||||
apiKeyCurrent := cfg.APIKeyValueKeyID.Valid && cfg.APIKeyValueKeyID.String == ciphers[0].HexDigest()
|
||||
customHeadersCurrent := cfg.CustomHeadersKeyID.Valid && cfg.CustomHeadersKeyID.String == ciphers[0].HexDigest()
|
||||
if oauth2Current && apiKeyCurrent && customHeadersCurrent {
|
||||
log.Debug(ctx, "skipping mcp server config", slog.F("mcp_server_config_id", cfg.ID), slog.F("slug", cfg.Slug), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateEncryptedMCPServerConfig(ctx, database.UpdateEncryptedMCPServerConfigParams{
|
||||
ID: cfg.ID,
|
||||
OAuth2ClientSecret: cfg.OAuth2ClientSecret,
|
||||
OAuth2ClientSecretKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
APIKeyValue: cfg.APIKeyValue,
|
||||
APIKeyValueKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
CustomHeaders: cfg.CustomHeaders,
|
||||
CustomHeadersKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update mcp server config id=%s slug=%s: %w", cfg.ID, cfg.Slug, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted mcp server config", slog.F("mcp_server_config_id", cfg.ID), slog.F("slug", cfg.Slug), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
// Revoke old keys
|
||||
for _, c := range ciphers[1:] {
|
||||
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
|
||||
@@ -288,6 +360,50 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
log.Debug(ctx, "decrypted user secret", slog.F("user_id", uid), slog.F("secret_name", secret.Name), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
mcpUserTokens, err := tx.GetMCPServerUserTokensByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get mcp server user tokens for user %s: %w", uid, err)
|
||||
}
|
||||
for _, tok := range mcpUserTokens {
|
||||
if !tok.AccessTokenKeyID.Valid && !tok.RefreshTokenKeyID.Valid {
|
||||
log.Debug(ctx, "skipping mcp server user token", slog.F("user_id", uid), slog.F("mcp_server_config_id", tok.MCPServerConfigID), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: tok.MCPServerConfigID,
|
||||
UserID: uid,
|
||||
AccessToken: tok.AccessToken,
|
||||
AccessTokenKeyID: sql.NullString{}, // clear the key ID
|
||||
RefreshToken: tok.RefreshToken,
|
||||
RefreshTokenKeyID: sql.NullString{}, // clear the key ID
|
||||
TokenType: tok.TokenType,
|
||||
Expiry: tok.Expiry,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("decrypt mcp server user token user_id=%s mcp_server_config_id=%s: %w", uid, tok.MCPServerConfigID, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted mcp server user token", slog.F("user_id", uid), slog.F("mcp_server_config_id", tok.MCPServerConfigID), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
mcpHeaderRows, err := tx.GetMCPServerUserHeaderValuesByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get mcp server user header values for user %s: %w", uid, err)
|
||||
}
|
||||
for _, row := range mcpHeaderRows {
|
||||
if !row.HeaderValuesKeyID.Valid {
|
||||
log.Debug(ctx, "skipping mcp server user header values", slog.F("user_id", uid), slog.F("mcp_server_config_id", row.MCPServerConfigID), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpsertMCPServerUserHeaderValues(ctx, database.UpsertMCPServerUserHeaderValuesParams{
|
||||
MCPServerConfigID: row.MCPServerConfigID,
|
||||
UserID: uid,
|
||||
HeaderValues: row.HeaderValues,
|
||||
HeaderValuesKeyID: sql.NullString{}, // clear the key ID
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("decrypt mcp server user header values user_id=%s mcp_server_config_id=%s: %w", uid, row.MCPServerConfigID, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted mcp server user header values", slog.F("user_id", uid), slog.F("mcp_server_config_id", row.MCPServerConfigID), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
return nil
|
||||
}, &database.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
@@ -358,6 +474,30 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
log.Debug(ctx, "decrypted user ai provider key", slog.F("user_ai_provider_key_id", key.ID), slog.F("ai_provider_id", key.AIProviderID), slog.F("user_id", key.UserID), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
mcpConfigs, err := cryptDB.GetMCPServerConfigs(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get mcp server configs: %w", err)
|
||||
}
|
||||
log.Info(ctx, "decrypting mcp server config secrets", slog.F("config_count", len(mcpConfigs)))
|
||||
for idx, cfg := range mcpConfigs {
|
||||
if !cfg.OAuth2ClientSecretKeyID.Valid && !cfg.APIKeyValueKeyID.Valid && !cfg.CustomHeadersKeyID.Valid {
|
||||
log.Debug(ctx, "skipping mcp server config", slog.F("mcp_server_config_id", cfg.ID), slog.F("slug", cfg.Slug), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateEncryptedMCPServerConfig(ctx, database.UpdateEncryptedMCPServerConfigParams{
|
||||
ID: cfg.ID,
|
||||
OAuth2ClientSecret: cfg.OAuth2ClientSecret,
|
||||
OAuth2ClientSecretKeyID: sql.NullString{}, // explicitly clear the key id
|
||||
APIKeyValue: cfg.APIKeyValue,
|
||||
APIKeyValueKeyID: sql.NullString{}, // explicitly clear the key id
|
||||
CustomHeaders: cfg.CustomHeaders,
|
||||
CustomHeadersKeyID: sql.NullString{}, // explicitly clear the key id
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("decrypt mcp server config id=%s slug=%s: %w", cfg.ID, cfg.Slug, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted mcp server config", slog.F("mcp_server_config_id", cfg.ID), slog.F("slug", cfg.Slug), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
// Revoke _all_ keys
|
||||
for _, c := range ciphers {
|
||||
if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil {
|
||||
@@ -388,6 +528,21 @@ UPDATE ai_providers
|
||||
WHERE settings_key_id IS NOT NULL;
|
||||
DELETE FROM ai_provider_keys
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
DELETE FROM mcp_server_user_tokens
|
||||
WHERE access_token_key_id IS NOT NULL
|
||||
OR refresh_token_key_id IS NOT NULL;
|
||||
DELETE FROM mcp_server_user_header_values
|
||||
WHERE header_values_key_id IS NOT NULL;
|
||||
UPDATE mcp_server_configs
|
||||
SET oauth2_client_secret = '',
|
||||
oauth2_client_secret_key_id = NULL,
|
||||
api_key_value = '',
|
||||
api_key_value_key_id = NULL,
|
||||
custom_headers = '{}',
|
||||
custom_headers_key_id = NULL
|
||||
WHERE oauth2_client_secret_key_id IS NOT NULL
|
||||
OR api_key_value_key_id IS NOT NULL
|
||||
OR custom_headers_key_id IS NOT NULL;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
|
||||
@@ -833,6 +833,38 @@ func (db *dbCrypt) InsertMCPServerConfig(ctx context.Context, params database.In
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// UpdateEncryptedMCPServerConfig re-encrypts the three MCP server
|
||||
// secret columns (oauth2_client_secret, api_key_value, custom_headers)
|
||||
// on a row so dbcrypt key rotation can move every cipher digest
|
||||
// before old keys are revoked. Empty inputs clear the matching
|
||||
// key_id so the row stays internally consistent.
|
||||
func (db *dbCrypt) UpdateEncryptedMCPServerConfig(ctx context.Context, params database.UpdateEncryptedMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
||||
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.APIKeyValue) == "" {
|
||||
params.APIKeyValueKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if strings.TrimSpace(params.CustomHeaders) == "" {
|
||||
params.CustomHeadersKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
|
||||
cfg, err := db.Store.UpdateEncryptedMCPServerConfig(ctx, params)
|
||||
if err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateMCPServerConfig(ctx context.Context, params database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
||||
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
||||
|
||||
Reference in New Issue
Block a user