feat: remove legacy chat provider tables (#25416)

This commit is contained in:
Michael Suchacz
2026-05-22 09:50:01 +02:00
committed by GitHub
parent ddec110b0e
commit ca1f6b19a2
46 changed files with 1270 additions and 3505 deletions
+29 -20
View File
@@ -122,14 +122,20 @@ func seedChatDependencies(
UserID: user.ID,
OrganizationID: org.ID,
})
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
BaseUrl: safetyNet.URL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
provider := dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeOpenai,
Name: "test-" + uuid.NewString(),
BaseUrl: safetyNet.URL,
})
dbgen.AIProviderKey(t, db, database.AIProviderKey{
ProviderID: provider.ID,
})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
IsDefault: true,
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
IsDefault: true,
})
return user, org, model
}
@@ -186,21 +192,24 @@ func setOpenAIProviderBaseURL(
) {
t.Helper()
provider, err := db.GetChatProviderByProvider(ctx, "openai")
require.NoError(t, err)
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
CentralApiKeyEnabled: true,
AllowUserApiKey: false,
AllowCentralApiKeyFallback: false,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
})
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
require.NoError(t, err)
for _, provider := range providers {
if provider.Type != database.AiProviderTypeOpenai {
continue
}
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
Enabled: provider.Enabled,
BaseUrl: baseURL,
Settings: provider.Settings,
SettingsKeyID: provider.SettingsKeyID,
})
require.NoError(t, err)
return
}
require.Fail(t, "openai provider not found")
}
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
+2 -105
View File
@@ -74,29 +74,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
}
}
userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid)
if err != nil {
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
}
for _, userProviderKey := range userProviderKeys {
if strings.TrimSpace(userProviderKey.APIKey) == "" {
continue
}
if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
UserID: userProviderKey.UserID,
ChatProviderID: userProviderKey.ChatProviderID,
APIKey: userProviderKey.APIKey,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
}); err != nil {
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
}
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid)
if err != nil {
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
@@ -134,35 +111,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
providers, err := cryptDB.GetChatProviders(ctx)
if err != nil {
return xerrors.Errorf("get chat providers: %w", err)
}
log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers)))
for idx, provider := range providers {
if strings.TrimSpace(provider.APIKey) == "" {
continue
}
if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
if err != nil {
return xerrors.Errorf("get ai providers: %w", err)
@@ -313,26 +261,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
}
}
userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid)
if err != nil {
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
}
for _, userProviderKey := range userProviderKeys {
if !userProviderKey.ApiKeyKeyID.Valid {
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
UserID: userProviderKey.UserID,
ChatProviderID: userProviderKey.ChatProviderID,
APIKey: userProviderKey.APIKey,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
}); err != nil {
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
}
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
}
userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid)
if err != nil {
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
@@ -370,31 +298,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
providers, err := cryptDB.GetChatProviders(ctx)
if err != nil {
return xerrors.Errorf("get chat providers: %w", err)
}
log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers)))
for idx, provider := range providers {
if !provider.ApiKeyKeyID.Valid {
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
if err != nil {
return xerrors.Errorf("get ai providers: %w", err)
@@ -475,16 +378,10 @@ DELETE FROM user_links
DELETE FROM external_auth_links
WHERE oauth_access_token_key_id IS NOT NULL
OR oauth_refresh_token_key_id IS NOT NULL;
DELETE FROM user_chat_provider_keys
WHERE api_key_key_id IS NOT NULL;
DELETE FROM user_ai_provider_keys
WHERE api_key_key_id IS NOT NULL;
DELETE FROM user_secrets
WHERE value_key_id IS NOT NULL;
UPDATE chat_providers
SET api_key = '',
api_key_key_id = NULL
WHERE api_key_key_id IS NOT NULL;
UPDATE ai_providers
SET settings = NULL,
settings_key_id = NULL
@@ -502,9 +399,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
store := database.New(sqlDB)
_, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens)
if err != nil {
return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err)
return xerrors.Errorf("delete encrypted tokens and AI provider keys: %w", err)
}
log.Info(ctx, "deleted encrypted user tokens and chat provider API keys")
log.Info(ctx, "deleted encrypted user tokens and AI provider API keys")
log.Info(ctx, "revoking all active keys")
keys, err := store.GetDBCryptKeys(ctx)
+13 -137
View File
@@ -521,6 +521,19 @@ func (db *dbCrypt) GetAIProviderKeysByProviderID(ctx context.Context, providerID
return keys, nil
}
func (db *dbCrypt) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) {
keys, err := db.Store.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptAIProviderKey(&keys[i]); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) InsertAIProviderKey(ctx context.Context, params database.InsertAIProviderKeyParams) (database.AIProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
@@ -576,92 +589,6 @@ func (db *dbCrypt) UpdateEncryptedAIProviderKey(ctx context.Context, params data
return key, nil
}
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByID(ctx, id)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetEnabledChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.InsertChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.UpdateChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error {
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
}
@@ -754,57 +681,6 @@ func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params
return key, nil
}
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
}
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.UserChatProviderKey{}, err
}
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := db.decryptUserChatProviderKey(&key); err != nil {
return database.UserChatProviderKey{}, err
}
return key, nil
}
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.UserChatProviderKey{}, err
}
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := db.decryptUserChatProviderKey(&key); err != nil {
return database.UserChatProviderKey{}, err
}
return key, nil
}
// decryptMCPServerConfig decrypts all encrypted fields on a
// single MCPServerConfig in place.
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
+12 -107
View File
@@ -1281,6 +1281,18 @@ func TestAIProviderKeys(t *testing.T) {
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
})
t.Run("GetAIProviderKeysByProviderIDs", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
keys, err := crypt.GetAIProviderKeysByProviderIDs(ctx, []uuid.UUID{provider.ID})
require.NoError(t, err)
require.Len(t, keys, 1)
requireAIProviderKeyDecrypted(t, keys[0], ciphers, apiKey)
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
})
t.Run("DeleteAIProviderKey", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
@@ -1558,113 +1570,6 @@ func TestMCPServerUserTokens(t *testing.T) {
})
}
func TestUserChatProviderKeys(t *testing.T) {
t.Parallel()
ctx := context.Background()
const (
//nolint:gosec // test credentials
initialAPIKey = "sk-initial-api-key-value"
//nolint:gosec // test credentials
updatedAPIKey = "sk-updated-api-key-value"
)
insertProviderAndKey := func(
t *testing.T,
crypt *dbCrypt,
ciphers []Cipher,
) (database.ChatProvider, database.UserChatProviderKey) {
t.Helper()
user := dbgen.User(t, crypt, database.User{})
provider := dbgen.ChatProvider(t, crypt, database.ChatProvider{
AllowUserApiKey: true,
}, func(params *database.InsertChatProviderParams) {
params.APIKey = ""
})
key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: user.ID,
ChatProviderID: provider.ID,
APIKey: initialAPIKey,
})
require.NoError(t, err)
require.Equal(t, initialAPIKey, key.APIKey)
require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String)
return provider, key
}
getUserChatProviderKey := func(t *testing.T, store interface {
GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error)
}, userID uuid.UUID, providerID uuid.UUID,
) database.UserChatProviderKey {
t.Helper()
keys, err := store.GetUserChatProviderKeys(ctx, userID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, providerID, keys[0].ChatProviderID)
return keys[0]
}
t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
require.Equal(t, key.ID, got.ID)
require.Equal(t, initialAPIKey, got.APIKey)
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
require.NotEqual(t, initialAPIKey, rawKey.APIKey)
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey)
})
t.Run("GetUserChatProviderKeys", func(t *testing.T) {
t.Parallel()
_, crypt, ciphers := setup(t)
_, key := insertProviderAndKey(t, crypt, ciphers)
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, key.ID, keys[0].ID)
require.Equal(t, initialAPIKey, keys[0].APIKey)
require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String)
})
t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: key.UserID,
ChatProviderID: provider.ID,
APIKey: updatedAPIKey,
})
require.NoError(t, err)
require.Equal(t, key.ID, updated.ID)
require.Equal(t, key.CreatedAt, updated.CreatedAt)
require.False(t, updated.UpdatedAt.Before(key.UpdatedAt))
require.Equal(t, updatedAPIKey, updated.APIKey)
require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String)
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
require.Equal(t, updatedAPIKey, got.APIKey)
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, updatedAPIKey, keys[0].APIKey)
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
require.NotEqual(t, updatedAPIKey, rawKey.APIKey)
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
})
}
func TestUserSecrets(t *testing.T) {
t.Parallel()
ctx := context.Background()