mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: remove legacy chat provider tables (#25416)
This commit is contained in:
@@ -122,14 +122,20 @@ func seedChatDependencies(
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
BaseUrl: safetyNet.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
provider := dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "test-" + uuid.NewString(),
|
||||
BaseUrl: safetyNet.URL,
|
||||
})
|
||||
dbgen.AIProviderKey(t, db, database.AIProviderKey{
|
||||
ProviderID: provider.ID,
|
||||
})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
IsDefault: true,
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
IsDefault: true,
|
||||
})
|
||||
return user, org, model
|
||||
}
|
||||
@@ -186,21 +192,24 @@ func setOpenAIProviderBaseURL(
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
provider, err := db.GetChatProviderByProvider(ctx, "openai")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: true,
|
||||
AllowUserApiKey: false,
|
||||
AllowCentralApiKeyFallback: false,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
})
|
||||
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
|
||||
require.NoError(t, err)
|
||||
for _, provider := range providers {
|
||||
if provider.Type != database.AiProviderTypeOpenai {
|
||||
continue
|
||||
}
|
||||
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
Enabled: provider.Enabled,
|
||||
BaseUrl: baseURL,
|
||||
Settings: provider.Settings,
|
||||
SettingsKeyID: provider.SettingsKeyID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
require.Fail(t, "openai provider not found")
|
||||
}
|
||||
|
||||
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
|
||||
|
||||
@@ -74,29 +74,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
|
||||
}
|
||||
for _, userProviderKey := range userProviderKeys {
|
||||
if strings.TrimSpace(userProviderKey.APIKey) == "" {
|
||||
continue
|
||||
}
|
||||
if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
|
||||
UserID: userProviderKey.UserID,
|
||||
ChatProviderID: userProviderKey.ChatProviderID,
|
||||
APIKey: userProviderKey.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
|
||||
@@ -134,35 +111,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
providers, err := cryptDB.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat providers: %w", err)
|
||||
}
|
||||
log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers)))
|
||||
for idx, provider := range providers {
|
||||
if strings.TrimSpace(provider.APIKey) == "" {
|
||||
continue
|
||||
}
|
||||
if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get ai providers: %w", err)
|
||||
@@ -313,26 +261,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
|
||||
}
|
||||
for _, userProviderKey := range userProviderKeys {
|
||||
if !userProviderKey.ApiKeyKeyID.Valid {
|
||||
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
|
||||
UserID: userProviderKey.UserID,
|
||||
ChatProviderID: userProviderKey.ChatProviderID,
|
||||
APIKey: userProviderKey.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
|
||||
@@ -370,31 +298,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
providers, err := cryptDB.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat providers: %w", err)
|
||||
}
|
||||
log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers)))
|
||||
for idx, provider := range providers {
|
||||
if !provider.ApiKeyKeyID.Valid {
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get ai providers: %w", err)
|
||||
@@ -475,16 +378,10 @@ DELETE FROM user_links
|
||||
DELETE FROM external_auth_links
|
||||
WHERE oauth_access_token_key_id IS NOT NULL
|
||||
OR oauth_refresh_token_key_id IS NOT NULL;
|
||||
DELETE FROM user_chat_provider_keys
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
DELETE FROM user_ai_provider_keys
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
DELETE FROM user_secrets
|
||||
WHERE value_key_id IS NOT NULL;
|
||||
UPDATE chat_providers
|
||||
SET api_key = '',
|
||||
api_key_key_id = NULL
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
UPDATE ai_providers
|
||||
SET settings = NULL,
|
||||
settings_key_id = NULL
|
||||
@@ -502,9 +399,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
|
||||
store := database.New(sqlDB)
|
||||
_, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err)
|
||||
return xerrors.Errorf("delete encrypted tokens and AI provider keys: %w", err)
|
||||
}
|
||||
log.Info(ctx, "deleted encrypted user tokens and chat provider API keys")
|
||||
log.Info(ctx, "deleted encrypted user tokens and AI provider API keys")
|
||||
|
||||
log.Info(ctx, "revoking all active keys")
|
||||
keys, err := store.GetDBCryptKeys(ctx)
|
||||
|
||||
+13
-137
@@ -521,6 +521,19 @@ func (db *dbCrypt) GetAIProviderKeysByProviderID(ctx context.Context, providerID
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
keys, err := db.Store.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range keys {
|
||||
if err := db.decryptAIProviderKey(&keys[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertAIProviderKey(ctx context.Context, params database.InsertAIProviderKeyParams) (database.AIProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
@@ -576,92 +589,6 @@ func (db *dbCrypt) UpdateEncryptedAIProviderKey(ctx context.Context, params data
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
provider, err := db.Store.GetChatProviderByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
|
||||
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
providers, err := db.Store.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range providers {
|
||||
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
providers, err := db.Store.GetEnabledChatProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range providers {
|
||||
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
|
||||
provider, err := db.Store.InsertChatProvider(ctx, params)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
|
||||
provider, err := db.Store.UpdateChatProvider(ctx, params)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error {
|
||||
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
|
||||
}
|
||||
@@ -754,57 +681,6 @@ func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
|
||||
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range keys {
|
||||
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
|
||||
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
|
||||
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// decryptMCPServerConfig decrypts all encrypted fields on a
|
||||
// single MCPServerConfig in place.
|
||||
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
|
||||
|
||||
@@ -1281,6 +1281,18 @@ func TestAIProviderKeys(t *testing.T) {
|
||||
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
|
||||
})
|
||||
|
||||
t.Run("GetAIProviderKeysByProviderIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
keys, err := crypt.GetAIProviderKeysByProviderIDs(ctx, []uuid.UUID{provider.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireAIProviderKeyDecrypted(t, keys[0], ciphers, apiKey)
|
||||
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
|
||||
})
|
||||
|
||||
t.Run("DeleteAIProviderKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
@@ -1558,113 +1570,6 @@ func TestMCPServerUserTokens(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserChatProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
//nolint:gosec // test credentials
|
||||
initialAPIKey = "sk-initial-api-key-value"
|
||||
//nolint:gosec // test credentials
|
||||
updatedAPIKey = "sk-updated-api-key-value"
|
||||
)
|
||||
|
||||
insertProviderAndKey := func(
|
||||
t *testing.T,
|
||||
crypt *dbCrypt,
|
||||
ciphers []Cipher,
|
||||
) (database.ChatProvider, database.UserChatProviderKey) {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
provider := dbgen.ChatProvider(t, crypt, database.ChatProvider{
|
||||
AllowUserApiKey: true,
|
||||
}, func(params *database.InsertChatProviderParams) {
|
||||
params.APIKey = ""
|
||||
})
|
||||
|
||||
key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: user.ID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: initialAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialAPIKey, key.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String)
|
||||
return provider, key
|
||||
}
|
||||
|
||||
getUserChatProviderKey := func(t *testing.T, store interface {
|
||||
GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error)
|
||||
}, userID uuid.UUID, providerID uuid.UUID,
|
||||
) database.UserChatProviderKey {
|
||||
t.Helper()
|
||||
keys, err := store.GetUserChatProviderKeys(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, providerID, keys[0].ChatProviderID)
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
|
||||
require.Equal(t, key.ID, got.ID)
|
||||
require.Equal(t, initialAPIKey, got.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
|
||||
|
||||
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
|
||||
require.NotEqual(t, initialAPIKey, rawKey.APIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey)
|
||||
})
|
||||
|
||||
t.Run("GetUserChatProviderKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, crypt, ciphers := setup(t)
|
||||
_, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, key.ID, keys[0].ID)
|
||||
require.Equal(t, initialAPIKey, keys[0].APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String)
|
||||
})
|
||||
|
||||
t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: key.UserID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: updatedAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key.ID, updated.ID)
|
||||
require.Equal(t, key.CreatedAt, updated.CreatedAt)
|
||||
require.False(t, updated.UpdatedAt.Before(key.UpdatedAt))
|
||||
require.Equal(t, updatedAPIKey, updated.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String)
|
||||
|
||||
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
|
||||
require.Equal(t, updatedAPIKey, got.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
|
||||
|
||||
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, updatedAPIKey, keys[0].APIKey)
|
||||
|
||||
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
|
||||
require.NotEqual(t, updatedAPIKey, rawKey.APIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
Reference in New Issue
Block a user