feat: remove legacy chat provider tables (#25416)

This commit is contained in:
Michael Suchacz
2026-05-22 09:50:01 +02:00
committed by GitHub
parent ddec110b0e
commit ca1f6b19a2
46 changed files with 1270 additions and 3505 deletions
+8 -16
View File
@@ -47,10 +47,9 @@ func FakeOpenAICompatProviderAPIKeys(t testing.TB) chatprovider.ProviderAPIKeys
}
// CreateOpenAICompatChatModelConfig creates the default provider and model
// config used by chat runtime tests. Tests that create chats should also set
// Options.ChatProviderAPIKeys, usually via FakeOpenAICompatProviderAPIKeys, so
// background chat work routes to a local provider until coderd closes. baseURL,
// when non-empty, is stored on the provider config.
// config used by chat runtime tests. Tests can pass a baseURL to route chat work
// to a specific local provider. If baseURL is empty, this helper starts a fake
// OpenAI-compatible provider.
func CreateOpenAICompatChatModelConfig(
t testing.TB,
client *codersdk.ExperimentalClient,
@@ -58,26 +57,19 @@ func CreateOpenAICompatChatModelConfig(
) codersdk.ChatModelConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: TestChatProviderOpenAICompat,
APIKey: TestChatProviderAPIKey,
BaseURL: baseURL,
})
require.NoError(t, err)
aiProviderBaseURL := baseURL
if aiProviderBaseURL == "" {
aiProviderBaseURL = "https://api.example.com/v1"
if baseURL == "" {
baseURL = chattest.OpenAI(t)
}
ctx := testutil.Context(t, testutil.WaitLong)
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderType(TestChatProviderOpenAICompat),
Name: "test-" + uuid.NewString(),
BaseURL: aiProviderBaseURL,
BaseURL: baseURL,
Enabled: true,
APIKeys: []string{TestChatProviderAPIKey},
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
+1 -3
View File
@@ -12,10 +12,9 @@ const (
CheckAiModelPricesOutputPriceCheck CheckConstraint = "ai_model_prices_output_price_check" // ai_model_prices
CheckAiProvidersNameCheck CheckConstraint = "ai_providers_name_check" // ai_providers
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsAiProviderRequiredWhenActive CheckConstraint = "chat_model_configs_ai_provider_required_when_active" // chat_model_configs
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
@@ -45,7 +44,6 @@ const (
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
CheckUserSkillsContentSize CheckConstraint = "user_skills_content_size" // user_skills
CheckUserSkillsDescriptionSize CheckConstraint = "user_skills_description_size" // user_skills
CheckUserSkillsNameFormat CheckConstraint = "user_skills_name_format" // user_skills
+14 -107
View File
@@ -1967,6 +1967,13 @@ func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) e
return q.db.DeleteChatModelConfigByID(ctx, id)
}
func (q *querier) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAIProvider); err != nil {
return err
}
return q.db.DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID)
}
func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -1974,13 +1981,6 @@ func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider
return q.db.DeleteChatModelConfigsByProvider(ctx, provider)
}
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatProviderByID(ctx, id)
}
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
@@ -2313,17 +2313,6 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
@@ -2617,6 +2606,13 @@ func (q *querier) GetAIProviderKeysByProviderID(ctx context.Context, providerID
return q.db.GetAIProviderKeysByProviderID(ctx, providerID)
}
func (q *querier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
return nil, err
}
return q.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
}
func (q *querier) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
return nil, err
@@ -3128,41 +3124,6 @@ func (q *querier) GetChatPlanModeInstructions(ctx context.Context) (string, erro
return q.db.GetChatPlanModeInstructions(ctx)
}
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByID(ctx, id)
}
func (q *querier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByIDForUpdate(ctx, id)
}
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByProvider(ctx, provider)
}
func (q *querier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.GetChatProviderByProviderForUpdate(ctx, provider)
}
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetChatProviders(ctx)
}
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
_, err := q.GetChatByID(ctx, chatID)
if err != nil {
@@ -3403,13 +3364,6 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch
return q.db.GetEnabledChatModelConfigs(ctx)
}
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledChatProviders(ctx)
}
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
@@ -4661,17 +4615,6 @@ func (q *querier) GetUserChatPersonalModelOverride(ctx context.Context, arg data
return q.db.GetUserChatPersonalModelOverride(ctx, arg)
}
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return nil, err
}
return q.db.GetUserChatProviderKeys(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
@@ -5525,13 +5468,6 @@ func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.Insert
return q.db.InsertChatModelConfig(ctx, arg)
}
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.InsertChatProvider(ctx, arg)
}
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
@@ -6749,13 +6685,6 @@ func (q *querier) UpdateChatPlanModeByID(ctx context.Context, arg database.Updat
return q.db.UpdateChatPlanModeByID(ctx, arg)
}
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
}
return q.db.UpdateChatProvider(ctx, arg)
}
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
// UpdateChatStatus is used by the chat processor to change chat status.
// It should be called with system context.
@@ -7429,17 +7358,6 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
}
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpdateUserChatProviderKey(ctx, arg)
}
func (q *querier) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) {
user, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
@@ -8371,17 +8289,6 @@ func (q *querier) UpsertUserChatPersonalModelOverride(ctx context.Context, arg d
return q.db.UpsertUserChatPersonalModelOverride(ctx, arg)
}
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpsertUserChatProviderKey(ctx, arg)
}
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
+18 -90
View File
@@ -557,10 +557,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().DeleteChatModelConfigsByProvider(gomock.Any(), providerName).Return(nil).AnyTimes()
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
s.Run("DeleteChatModelConfigsByAIProviderID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
providerID := uuid.New()
dbm.EXPECT().DeleteChatModelConfigsByAIProviderID(gomock.Any(), providerID).Return(nil).AnyTimes()
check.Args(providerID).Asserts(rbac.ResourceAIProvider, policy.ActionDelete)
}))
s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -972,34 +972,7 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
}))
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
}))
s.Run("GetChatProviderByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviderByIDForUpdate(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerName := "test-provider"
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
}))
s.Run("GetChatProviderByProviderForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerName := "test-provider"
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
dbm.EXPECT().GetChatProviderByProviderForUpdate(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes()
@@ -1112,12 +1085,7 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
}))
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
threshold := dbtime.Now()
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
@@ -1166,17 +1134,7 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertChatProviderParams{
Provider: "test-provider",
DisplayName: "Test Provider",
APIKey: "test-api-key",
Enabled: true,
}
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
@@ -1303,17 +1261,7 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
arg := database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: "Updated Provider",
APIKey: "updated-api-key",
Enabled: true,
}
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
}))
s.Run("UpdateChatPinOrder", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatPinOrderParams{
@@ -2933,36 +2881,7 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
}))
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
}))
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
}))
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("GetUserAIProviderKeyByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.GetUserAIProviderKeyByProviderIDParams{UserID: u.ID, AIProviderID: uuid.New()}
@@ -6582,6 +6501,15 @@ func (s *MethodTestSuite) TestAIBridge() {
dbm.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), provider.ID).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB})
}))
s.Run("GetAIProviderKeysByProviderIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.AIProvider{})
providerB := testutil.Fake(s.T(), faker, database.AIProvider{})
providerIDs := []uuid.UUID{providerA.ID, providerB.ID}
keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerA.ID})
keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerB.ID})
dbm.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), providerIDs).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes()
check.Args(providerIDs).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB})
}))
s.Run("GetAIProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{})
keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{})
+53 -5
View File
@@ -149,8 +149,29 @@ const (
func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelConfig, munge ...func(*database.InsertChatModelConfigParams)) database.ChatModelConfig {
t.Helper()
providerName := takeFirst(seed.Provider, "openai")
aiProviderID := seed.AIProviderID
if !aiProviderID.Valid {
providers, err := db.GetAIProviders(genCtx, database.GetAIProvidersParams{IncludeDisabled: true})
require.NoError(t, err, "get ai providers")
var provider database.AIProvider
for _, candidate := range providers {
if candidate.Type != database.AIProviderType(providerName) {
continue
}
if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) {
provider = candidate
}
}
if provider.ID == uuid.Nil {
provider = AIProvider(t, db, database.AIProvider{
Type: database.AIProviderType(providerName),
})
}
aiProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true}
}
params := database.InsertChatModelConfigParams{
Provider: takeFirst(seed.Provider, "openai"),
Provider: providerName,
Model: takeFirst(seed.Model, "gpt-4o-mini"),
DisplayName: takeFirst(seed.DisplayName, "Test Model"),
CreatedBy: seed.CreatedBy,
@@ -160,7 +181,7 @@ func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelCon
ContextLimit: takeFirst(seed.ContextLimit, defaultChatModelContextLimit),
CompressionThreshold: takeFirst(seed.CompressionThreshold, defaultChatModelCompressionThreshold),
Options: takeFirstSlice(seed.Options, json.RawMessage(`{}`)),
AIProviderID: seed.AIProviderID,
AIProviderID: aiProviderID,
}
for _, fn := range munge {
fn(&params)
@@ -263,9 +284,36 @@ func ChatProvider(t testing.TB, db database.Store, seed database.ChatProvider, m
for _, fn := range munge {
fn(&params)
}
provider, err := db.InsertChatProvider(genCtx, params)
require.NoError(t, err, "insert chat provider")
return provider
provider := AIProvider(t, db, database.AIProvider{
Type: database.AIProviderType(params.Provider),
Name: "test-" + uuid.NewString(),
DisplayName: sql.NullString{String: params.DisplayName, Valid: params.DisplayName != ""},
BaseUrl: params.BaseUrl,
}, func(p *database.InsertAIProviderParams) {
p.Enabled = params.Enabled
})
if params.APIKey != "" {
AIProviderKey(t, db, database.AIProviderKey{
ProviderID: provider.ID,
APIKey: params.APIKey,
ApiKeyKeyID: params.ApiKeyKeyID,
})
}
return database.ChatProvider{
ID: provider.ID,
Provider: params.Provider,
DisplayName: params.DisplayName,
APIKey: params.APIKey,
BaseUrl: params.BaseUrl,
ApiKeyKeyID: params.ApiKeyKeyID,
CreatedBy: params.CreatedBy,
Enabled: params.Enabled,
CentralApiKeyEnabled: params.CentralApiKeyEnabled,
AllowUserApiKey: params.AllowUserApiKey,
AllowCentralApiKeyFallback: params.AllowCentralApiKeyFallback,
CreatedAt: provider.CreatedAt,
UpdatedAt: provider.UpdatedAt,
}
}
func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerConfig) database.MCPServerConfig {
+16 -104
View File
@@ -465,6 +465,14 @@ func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uui
return r0
}
func (m queryMetricsStore) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID)
m.queryLatencies.WithLabelValues("DeleteChatModelConfigsByAIProviderID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigsByAIProviderID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigsByProvider(ctx, provider)
@@ -473,14 +481,6 @@ func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context,
return r0
}
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatProviderByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
start := time.Now()
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
@@ -809,14 +809,6 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
return r0
}
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
start := time.Now()
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
return r0
}
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
@@ -1089,6 +1081,14 @@ func (m queryMetricsStore) GetAIProviderKeysByProviderID(ctx context.Context, pr
return r0, r1
}
func (m queryMetricsStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetAIProviderKeysByProviderIDs(ctx, providerIds)
m.queryLatencies.WithLabelValues("GetAIProviderKeysByProviderIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderKeysByProviderIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) {
start := time.Now()
r0, r1 := m.s.GetAIProviders(ctx, arg)
@@ -1537,46 +1537,6 @@ func (m queryMetricsStore) GetChatPlanModeInstructions(ctx context.Context) (str
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByIDForUpdate(ctx, id)
m.queryLatencies.WithLabelValues("GetChatProviderByIDForUpdate").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByIDForUpdate").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByProviderForUpdate(ctx, provider)
m.queryLatencies.WithLabelValues("GetChatProviderByProviderForUpdate").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProviderForUpdate").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviders(ctx)
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
@@ -1833,14 +1793,6 @@ func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]da
return r0, r1
}
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledChatProviders(ctx)
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
return r0, r1
}
func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx)
@@ -3033,14 +2985,6 @@ func (m queryMetricsStore) GetUserChatPersonalModelOverride(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
@@ -3825,14 +3769,6 @@ func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg databa
return r0, r1
}
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.InsertChatProvider(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
@@ -4881,14 +4817,6 @@ func (m queryMetricsStore) UpdateChatPlanModeByID(ctx context.Context, arg datab
return r0, r1
}
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
@@ -5321,14 +5249,6 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserCodeDiffDisplayMode(ctx, arg)
@@ -6105,14 +6025,6 @@ func (m queryMetricsStore) UpsertUserChatPersonalModelOverride(ctx context.Conte
return r0
}
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
start := time.Now()
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
+29 -193
View File
@@ -760,6 +760,20 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigByID(ctx, id any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigByID), ctx, id)
}
// DeleteChatModelConfigsByAIProviderID mocks base method.
func (m *MockStore) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatModelConfigsByAIProviderID", ctx, aiProviderID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatModelConfigsByAIProviderID indicates an expected call of DeleteChatModelConfigsByAIProviderID.
func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByAIProviderID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByAIProviderID), ctx, aiProviderID)
}
// DeleteChatModelConfigsByProvider mocks base method.
func (m *MockStore) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error {
m.ctrl.T.Helper()
@@ -774,20 +788,6 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByProvider(ctx, provider
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByProvider", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByProvider), ctx, provider)
}
// DeleteChatProviderByID mocks base method.
func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID.
func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id)
}
// DeleteChatQueuedMessage mocks base method.
func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
m.ctrl.T.Helper()
@@ -1376,20 +1376,6 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
}
// DeleteUserChatProviderKey mocks base method.
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
}
// DeleteUserSecretByUserIDAndName mocks base method.
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
@@ -1889,6 +1875,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderID(ctx, providerID a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderID", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderID), ctx, providerID)
}
// GetAIProviderKeysByProviderIDs mocks base method.
func (m *MockStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAIProviderKeysByProviderIDs", ctx, providerIds)
ret0, _ := ret[0].([]database.AIProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAIProviderKeysByProviderIDs indicates an expected call of GetAIProviderKeysByProviderIDs.
func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderIDs(ctx, providerIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderIDs", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderIDs), ctx, providerIds)
}
// GetAIProviders mocks base method.
func (m *MockStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) {
m.ctrl.T.Helper()
@@ -2849,81 +2850,6 @@ func (mr *MockStoreMockRecorder) GetChatPlanModeInstructions(ctx any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatPlanModeInstructions", reflect.TypeOf((*MockStore)(nil).GetChatPlanModeInstructions), ctx)
}
// GetChatProviderByID mocks base method.
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviderByID indicates an expected call of GetChatProviderByID.
func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id)
}
// GetChatProviderByIDForUpdate mocks base method.
func (m *MockStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviderByIDForUpdate", ctx, id)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviderByIDForUpdate indicates an expected call of GetChatProviderByIDForUpdate.
func (mr *MockStoreMockRecorder) GetChatProviderByIDForUpdate(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByIDForUpdate), ctx, id)
}
// GetChatProviderByProvider mocks base method.
func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider.
func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider)
}
// GetChatProviderByProviderForUpdate mocks base method.
func (m *MockStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviderByProviderForUpdate", ctx, provider)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviderByProviderForUpdate indicates an expected call of GetChatProviderByProviderForUpdate.
func (mr *MockStoreMockRecorder) GetChatProviderByProviderForUpdate(ctx, provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProviderForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProviderForUpdate), ctx, provider)
}
// GetChatProviders mocks base method.
func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatProviders", ctx)
ret0, _ := ret[0].([]database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatProviders indicates an expected call of GetChatProviders.
func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx)
}
// GetChatQueuedMessages mocks base method.
func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
m.ctrl.T.Helper()
@@ -3404,21 +3330,6 @@ func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx)
}
// GetEnabledChatProviders mocks base method.
func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx)
ret0, _ := ret[0].([]database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders.
func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
}
// GetEnabledMCPServerConfigs mocks base method.
func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
@@ -5684,21 +5595,6 @@ func (mr *MockStoreMockRecorder) GetUserChatPersonalModelOverride(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).GetUserChatPersonalModelOverride), ctx, arg)
}
// GetUserChatProviderKeys mocks base method.
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
ret0, _ := ret[0].([]database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
}
// GetUserChatSpendInPeriod mocks base method.
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
m.ctrl.T.Helper()
@@ -7183,21 +7079,6 @@ func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg)
}
// InsertChatProvider mocks base method.
func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatProvider indicates an expected call of InsertChatProvider.
func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg)
}
// InsertChatQueuedMessage mocks base method.
func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
m.ctrl.T.Helper()
@@ -9234,21 +9115,6 @@ func (mr *MockStoreMockRecorder) UpdateChatPlanModeByID(ctx, arg any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPlanModeByID", reflect.TypeOf((*MockStore)(nil).UpdateChatPlanModeByID), ctx, arg)
}
// UpdateChatProvider mocks base method.
func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg)
ret0, _ := ret[0].(database.ChatProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatProvider indicates an expected call of UpdateChatProvider.
func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg)
}
// UpdateChatStatus mocks base method.
func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
m.ctrl.T.Helper()
@@ -10035,21 +9901,6 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
}
// UpdateUserChatProviderKey mocks base method.
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
}
// UpdateUserCodeDiffDisplayMode mocks base method.
func (m *MockStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) {
m.ctrl.T.Helper()
@@ -11446,21 +11297,6 @@ func (mr *MockStoreMockRecorder) UpsertUserChatPersonalModelOverride(ctx, arg an
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertUserChatPersonalModelOverride), ctx, arg)
}
// UpsertUserChatProviderKey mocks base method.
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
}
// UpsertWebpushVAPIDKeys mocks base method.
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
m.ctrl.T.Helper()
+1 -60
View File
@@ -1531,30 +1531,11 @@ CREATE TABLE chat_model_configs (
compression_threshold integer NOT NULL,
options jsonb DEFAULT '{}'::jsonb NOT NULL,
ai_provider_id uuid,
CONSTRAINT chat_model_configs_ai_provider_required_when_active CHECK (((deleted = true) OR (ai_provider_id IS NOT NULL))),
CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))),
CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0))
);
CREATE TABLE chat_providers (
id uuid DEFAULT gen_random_uuid() NOT NULL,
provider text NOT NULL,
display_name text DEFAULT ''::text NOT NULL,
api_key text DEFAULT ''::text NOT NULL,
api_key_key_id text,
created_by uuid,
enabled boolean DEFAULT true NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
base_url text DEFAULT ''::text NOT NULL,
central_api_key_enabled boolean DEFAULT true NOT NULL,
allow_user_api_key boolean DEFAULT false NOT NULL,
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
);
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
CREATE TABLE chat_queued_messages (
id bigint NOT NULL,
chat_id uuid NOT NULL,
@@ -3040,17 +3021,6 @@ COMMENT ON COLUMN user_ai_provider_keys.api_key IS 'User-owned API key used to a
COMMENT ON COLUMN user_ai_provider_keys.api_key_key_id IS 'The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted.';
CREATE TABLE user_chat_provider_keys (
id uuid DEFAULT gen_random_uuid() NOT NULL,
user_id uuid NOT NULL,
chat_provider_id uuid NOT NULL,
api_key text NOT NULL,
api_key_key_id text,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
);
CREATE TABLE user_configs (
user_id uuid NOT NULL,
key character varying(256) NOT NULL,
@@ -3667,12 +3637,6 @@ ALTER TABLE ONLY chat_messages
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
@@ -3889,12 +3853,6 @@ ALTER TABLE ONLY user_ai_provider_keys
ALTER TABLE ONLY user_ai_provider_keys
ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
ALTER TABLE ONLY user_configs
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
@@ -4112,8 +4070,6 @@ CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING b
CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
@@ -4444,12 +4400,6 @@ ALTER TABLE ONLY chat_model_configs
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY chat_providers
ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -4687,15 +4637,6 @@ ALTER TABLE ONLY user_ai_provider_keys
ALTER TABLE ONLY user_ai_provider_keys
ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_configs
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
@@ -24,8 +24,6 @@ const (
ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id);
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
@@ -105,9 +103,6 @@ const (
ForeignKeyUserAiProviderKeysAiProviderID ForeignKeyConstraint = "user_ai_provider_keys_ai_provider_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE;
ForeignKeyUserAiProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_ai_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserAiProviderKeysUserID ForeignKeyConstraint = "user_ai_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -0,0 +1,44 @@
package database
import (
"database/sql"
"time"
"github.com/google/uuid"
)
// ChatProvider is the fixture shape accepted by dbgen.ChatProvider.
//
//nolint:revive
type ChatProvider struct {
ID uuid.UUID
Provider string
DisplayName string
APIKey string
BaseUrl string
ApiKeyKeyID sql.NullString
CreatedAt time.Time
UpdatedAt time.Time
CreatedBy uuid.NullUUID
Enabled bool
CentralApiKeyEnabled bool
AllowUserApiKey bool
AllowCentralApiKeyFallback bool
}
// InsertChatProviderParams is the callback parameter shape accepted by
// dbgen.ChatProvider.
//
//nolint:revive
type InsertChatProviderParams struct {
Provider string
DisplayName string
APIKey string
BaseUrl string
ApiKeyKeyID sql.NullString
CreatedBy uuid.NullUUID
Enabled bool
CentralApiKeyEnabled bool
AllowUserApiKey bool
AllowCentralApiKeyFallback bool
}
@@ -1,8 +1,15 @@
DROP TABLE IF EXISTS user_chat_provider_keys;
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
DO $$
BEGIN
IF to_regclass('chat_providers') IS NULL THEN
RETURN;
END IF;
ALTER TABLE chat_providers
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
ALTER TABLE chat_providers
DROP COLUMN IF EXISTS central_api_key_enabled,
DROP COLUMN IF EXISTS allow_user_api_key,
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
END $$;
@@ -1,27 +1,34 @@
-- Restore placeholder provider rows before re-adding the provider FK.
--
-- The companion up migration dropped chat_model_configs.provider's foreign
-- key, so historical model-config rows can outlive a deleted provider row.
-- These backfilled providers are deliberately disabled stubs with empty
-- credential fields, which lets rollback restore referential integrity
-- without re-enabling a provider. This insert depends on the current
-- provider whitelist still admitting every historical
-- chat_model_configs.provider value, and on the omitted columns keeping
-- compatible defaults. Operators restoring a real provider should update the
-- stub row, including credential-policy flags such as
-- central_api_key_enabled, before enabling it, rather than insert a second
-- row with the same provider name.
INSERT INTO chat_providers (provider, enabled)
SELECT DISTINCT
DO $$
BEGIN
IF to_regclass('chat_providers') IS NULL THEN
RETURN;
END IF;
-- Restore placeholder provider rows before re-adding the provider FK.
--
-- The companion up migration dropped chat_model_configs.provider's foreign
-- key, so historical model-config rows can outlive a deleted provider row.
-- These backfilled providers are deliberately disabled stubs with empty
-- credential fields, which lets rollback restore referential integrity
-- without re-enabling a provider. This insert depends on the current
-- provider whitelist still admitting every historical
-- chat_model_configs.provider value, and on the omitted columns keeping
-- compatible defaults. Operators restoring a real provider should update the
-- stub row, including credential-policy flags such as
-- central_api_key_enabled, before enabling it, rather than insert a second
-- row with the same provider name.
INSERT INTO chat_providers (provider, enabled)
SELECT DISTINCT
cmc.provider,
FALSE
FROM
FROM
chat_model_configs cmc
LEFT JOIN
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider
WHERE
WHERE
cp.provider IS NULL;
ALTER TABLE chat_model_configs
ALTER TABLE chat_model_configs
ADD CONSTRAINT chat_model_configs_provider_fkey
FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
END $$;
@@ -1,17 +1,10 @@
WITH migrated_provider_ids AS (
SELECT id
FROM chat_providers
UNION
SELECT id
FROM ai_providers
WHERE name LIKE 'agents-%'
AND deleted = TRUE
)
UPDATE chat_model_configs
SET ai_provider_id = NULL
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
DO $$
BEGIN
IF to_regclass('chat_providers') IS NULL THEN
RETURN;
END IF;
WITH migrated_provider_ids AS (
WITH migrated_provider_ids AS (
SELECT id
FROM chat_providers
UNION
@@ -19,11 +12,12 @@ WITH migrated_provider_ids AS (
FROM ai_providers
WHERE name LIKE 'agents-%'
AND deleted = TRUE
)
DELETE FROM user_ai_provider_keys
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
)
UPDATE chat_model_configs
SET ai_provider_id = NULL
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
WITH migrated_provider_ids AS (
WITH migrated_provider_ids AS (
SELECT id
FROM chat_providers
UNION
@@ -31,11 +25,11 @@ WITH migrated_provider_ids AS (
FROM ai_providers
WHERE name LIKE 'agents-%'
AND deleted = TRUE
)
DELETE FROM ai_provider_keys
WHERE provider_id IN (SELECT id FROM migrated_provider_ids);
)
DELETE FROM user_ai_provider_keys
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
WITH migrated_provider_ids AS (
WITH migrated_provider_ids AS (
SELECT id
FROM chat_providers
UNION
@@ -43,6 +37,19 @@ WITH migrated_provider_ids AS (
FROM ai_providers
WHERE name LIKE 'agents-%'
AND deleted = TRUE
)
DELETE FROM ai_providers
WHERE id IN (SELECT id FROM migrated_provider_ids);
)
DELETE FROM ai_provider_keys
WHERE provider_id IN (SELECT id FROM migrated_provider_ids);
WITH migrated_provider_ids AS (
SELECT id
FROM chat_providers
UNION
SELECT id
FROM ai_providers
WHERE name LIKE 'agents-%'
AND deleted = TRUE
)
DELETE FROM ai_providers
WHERE id IN (SELECT id FROM migrated_provider_ids);
END $$;
@@ -0,0 +1,3 @@
-- no-op. Legacy chat provider tables are intentionally not recreated from AI
-- provider definitions. Rolling back past this migration is not reversible at
-- the schema level.
@@ -0,0 +1,140 @@
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM chat_providers cp
JOIN ai_providers ap ON ap.name = 'agents-' || cp.provider
WHERE ap.deleted = FALSE
AND ap.id != cp.id
) THEN
RAISE EXCEPTION 'cannot finalize chat provider migration because a live agents-* AI provider name already exists';
END IF;
END $$;
INSERT INTO ai_providers (
id,
type,
name,
display_name,
enabled,
base_url,
created_at,
updated_at
)
SELECT
cp.id,
cp.provider::ai_provider_type,
'agents-' || cp.provider,
NULLIF(cp.display_name, ''),
cp.enabled,
cp.base_url,
cp.created_at,
cp.updated_at
FROM chat_providers cp
WHERE NOT EXISTS (
SELECT 1
FROM ai_providers ap
WHERE ap.id = cp.id
);
UPDATE ai_providers ap
SET
type = cp.provider::ai_provider_type,
name = 'agents-' || cp.provider,
display_name = NULLIF(cp.display_name, ''),
enabled = cp.enabled,
deleted = FALSE,
base_url = cp.base_url,
updated_at = GREATEST(cp.updated_at, ap.updated_at)
FROM chat_providers cp
WHERE ap.id = cp.id
AND (cp.updated_at > ap.updated_at OR ap.deleted);
DELETE FROM ai_provider_keys apk
USING chat_providers cp
WHERE cp.id = apk.provider_id
AND cp.api_key = ''
AND cp.updated_at > apk.updated_at;
WITH runtime_provider_keys AS (
SELECT DISTINCT ON (apk.provider_id)
apk.id,
apk.provider_id
FROM ai_provider_keys apk
JOIN chat_providers cp ON cp.id = apk.provider_id
WHERE cp.api_key != ''
ORDER BY
apk.provider_id ASC,
apk.created_at ASC,
apk.id ASC
)
UPDATE ai_provider_keys apk
SET
api_key = cp.api_key,
api_key_key_id = cp.api_key_key_id,
updated_at = cp.updated_at
FROM runtime_provider_keys rpk
JOIN chat_providers cp ON cp.id = rpk.provider_id
WHERE apk.id = rpk.id
AND cp.updated_at > apk.updated_at;
INSERT INTO ai_provider_keys (
id,
provider_id,
api_key,
api_key_key_id,
created_at,
updated_at
)
SELECT
gen_random_uuid(),
cp.id,
cp.api_key,
cp.api_key_key_id,
cp.updated_at,
cp.updated_at
FROM chat_providers cp
WHERE cp.api_key != ''
AND NOT EXISTS (
SELECT 1
FROM ai_provider_keys apk
WHERE apk.provider_id = cp.id
);
INSERT INTO user_ai_provider_keys (
id,
user_id,
ai_provider_id,
api_key,
api_key_key_id,
created_at,
updated_at
)
SELECT
ucpk.id,
ucpk.user_id,
ucpk.chat_provider_id,
ucpk.api_key,
ucpk.api_key_key_id,
ucpk.created_at,
ucpk.updated_at
FROM user_chat_provider_keys ucpk
ON CONFLICT (user_id, ai_provider_id) DO UPDATE
SET
api_key = EXCLUDED.api_key,
api_key_key_id = EXCLUDED.api_key_key_id,
updated_at = EXCLUDED.updated_at
WHERE user_ai_provider_keys.updated_at < EXCLUDED.updated_at;
UPDATE chat_model_configs cmc
SET ai_provider_id = cp.id
FROM chat_providers cp
WHERE cmc.provider = cp.provider
AND cmc.ai_provider_id IS NULL;
ALTER TABLE chat_model_configs
ADD CONSTRAINT chat_model_configs_ai_provider_required_when_active
CHECK (deleted = TRUE OR ai_provider_id IS NOT NULL);
DROP TABLE IF EXISTS user_chat_provider_keys;
DROP TABLE IF EXISTS chat_providers;
-27
View File
@@ -4682,23 +4682,6 @@ type ChatModelConfig struct {
AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"`
}
type ChatProvider struct {
ID uuid.UUID `db:"id" json:"id"`
Provider string `db:"provider" json:"provider"`
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
BaseUrl string `db:"base_url" json:"base_url"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
}
type ChatQueuedMessage struct {
ID int64 `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
@@ -5706,16 +5689,6 @@ type UserAiProviderKey struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type UserChatProviderKey struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type UserConfig struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Key string `db:"key" json:"key"`
+4 -13
View File
@@ -122,8 +122,8 @@ type sqlcQuerier interface {
// archive-cleanup retry).
DeleteChatDebugDataByChatID(ctx context.Context, arg DeleteChatDebugDataByChatIDParams) (int64, error)
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error
DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error
DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error
@@ -196,7 +196,6 @@ type sqlcQuerier interface {
DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error
DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (UserSecret, error)
DeleteUserSkillByUserIDAndName(ctx context.Context, arg DeleteUserSkillByUserIDAndNameParams) (UserSkill, error)
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
@@ -272,6 +271,9 @@ type sqlcQuerier interface {
// key per provider; multiple keys are stored to support future
// failover and rotation flows.
GetAIProviderKeysByProviderID(ctx context.Context, providerID uuid.UUID) ([]AIProviderKey, error)
// Returns all keys for the requested providers, ordered by provider then created_at ASC
// so callers can select the oldest non-empty key per provider without issuing N queries.
GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error)
// Returns AI provider rows. Soft-deleted and disabled rows are excluded
// unless include_deleted or include_disabled is set.
GetAIProviders(ctx context.Context, arg GetAIProvidersParams) ([]AIProvider, error)
@@ -384,11 +386,6 @@ type sqlcQuerier interface {
// personal chat model overrides. It defaults to false when unset.
GetChatPersonalModelOverridesEnabled(ctx context.Context) (bool, error)
GetChatPlanModeInstructions(ctx context.Context) (string, error)
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error)
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error)
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
// Returns the chat retention period in days. Chats archived longer
// than this and orphaned chat files older than this are purged by
@@ -452,7 +449,6 @@ type sqlcQuerier interface {
// Check both to ensure the selected config is actually usable.
GetEnabledChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
@@ -760,7 +756,6 @@ type sqlcQuerier interface {
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error)
GetUserChatPersonalModelOverride(ctx context.Context, arg GetUserChatPersonalModelOverrideParams) (string, error)
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
// Returns the total spend for a user in the given period.
// When organization_id is NULL, spend across all organizations is
// returned (global behavior). Otherwise only spend within the
@@ -932,7 +927,6 @@ type sqlcQuerier interface {
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error)
InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error)
InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error)
@@ -1214,7 +1208,6 @@ type sqlcQuerier interface {
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error
UpdateChatPlanModeByID(ctx context.Context, arg UpdateChatPlanModeByIDParams) (Chat, error)
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error)
UpdateChatTitleByID(ctx context.Context, arg UpdateChatTitleByIDParams) (Chat, error)
@@ -1283,7 +1276,6 @@ type sqlcQuerier interface {
UpdateUserAgentChatSendShortcut(ctx context.Context, arg UpdateUserAgentChatSendShortcutParams) (string, error)
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
UpdateUserCodeDiffDisplayMode(ctx context.Context, arg UpdateUserCodeDiffDisplayModeParams) (string, error)
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
@@ -1408,7 +1400,6 @@ type sqlcQuerier interface {
UpsertUserAIProviderKey(ctx context.Context, arg UpsertUserAIProviderKeyParams) (UserAiProviderKey, error)
UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error
UpsertUserChatPersonalModelOverride(ctx context.Context, arg UpsertUserChatPersonalModelOverrideParams) error
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
+88 -73
View File
@@ -10609,20 +10609,12 @@ func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) {
}, func(params *database.InsertChatModelConfigParams) {
params.Enabled = false
})
legacyProvider := dbgen.ChatProvider(t, store, database.ChatProvider{Provider: "google"})
legacyConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
Provider: legacyProvider.Provider,
Model: "google-model-" + uuid.NewString(),
})
configs, err := store.GetEnabledChatModelConfigs(ctx)
require.NoError(t, err)
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == enabledConfig.ID
}))
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == legacyConfig.ID
}))
require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == disabledProviderConfig.ID
}))
@@ -10636,6 +10628,46 @@ func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) {
_, err = store.GetEnabledChatModelConfigByID(ctx, disabledProviderConfig.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
_, err = store.GetEnabledChatModelConfigByID(ctx, disabledModelConfig.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
}
func insertChatModelConfigForTest(
ctx context.Context,
t testing.TB,
store database.Store,
params database.InsertChatModelConfigParams,
) (database.ChatModelConfig, error) {
t.Helper()
if params.AIProviderID.Valid {
return store.InsertChatModelConfig(ctx, params)
}
providerName := params.Provider
if providerName == "" {
providerName = "openai"
params.Provider = providerName
}
providers, err := store.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
if err != nil {
return database.ChatModelConfig{}, err
}
var provider database.AIProvider
for _, candidate := range providers {
if candidate.Type != database.AIProviderType(providerName) {
continue
}
if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) {
provider = candidate
}
}
if provider.ID == uuid.Nil {
provider = dbgen.AIProvider(t, store, database.AIProvider{
Type: database.AIProviderType(providerName),
})
}
params.AIProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true}
return store.InsertChatModelConfig(ctx, params)
}
func TestInsertChatMessages(t *testing.T) {
@@ -10653,7 +10685,7 @@ func TestInsertChatMessages(t *testing.T) {
) database.ChatModelConfig {
t.Helper()
modelConfig, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelConfig, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: provider,
Model: model,
DisplayName: displayName,
@@ -10681,14 +10713,13 @@ func TestInsertChatMessages(t *testing.T) {
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
provider := "openai"
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: provider,
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelConfigA := insertModelConfig(
t,
@@ -10850,18 +10881,21 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
// A chat_providers row is required as a FK for model configs.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
// An AI provider row is required as a FK for model configs.
provider := dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeOpenai,
Name: "test-" + uuid.NewString(),
DisplayName: sql.NullString{String: "OpenAI", Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
dbgen.AIProviderKey(t, db, database.AIProviderKey{
ProviderID: provider.ID,
APIKey: "test-key",
})
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
Model: "test-model",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
@@ -11227,16 +11261,15 @@ func TestGetPRInsights(t *testing.T) {
user := dbgen.User(t, store, database.User{})
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
mc, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
mc, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: "claude-4",
DisplayName: "Claude 4",
@@ -11683,7 +11716,7 @@ func TestGetPRInsights(t *testing.T) {
store, userID, _, orgID := setupChatInfra(t)
const modelName = "claude-4.1"
emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
emptyDisplayModel, err := insertChatModelConfigForTest(context.Background(), t, store, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: modelName,
DisplayName: "",
@@ -11791,16 +11824,15 @@ func TestChatPinOrderQueries(t *testing.T) {
// Use background context for fixture setup so the
// timed test context doesn't tick during DB init.
bg := context.Background()
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
@@ -11972,16 +12004,15 @@ func TestChatPinOrderConstraints(t *testing.T) {
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
bg := context.Background()
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
@@ -12065,16 +12096,15 @@ func TestChatLabels(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
@@ -12365,16 +12395,15 @@ func TestUpdateChatLastTurnSummary(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
@@ -12466,16 +12495,15 @@ func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) {
providerName := "openai"
modelName := "debug-model-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -12659,16 +12687,15 @@ func TestDeleteChatDebugDataAfterMessageIDStepLevelFieldBoundariesAndNulls(t *te
providerName := "openai"
modelName := "debug-model-step-boundaries-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -12917,16 +12944,15 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) {
providerName := "openai"
modelName := "debug-model-finalize-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -13356,16 +13382,15 @@ func TestChatDebugSQLGuards(t *testing.T) {
providerName := "openai"
modelName := "debug-model-guards-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -13490,16 +13515,15 @@ func TestChatDebugRunCOALESCEPreservation(t *testing.T) {
providerName := "openai"
modelName := "debug-model-coalesce-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -13605,16 +13629,15 @@ func TestChatDebugStepCOALESCEPreservation(t *testing.T) {
providerName := "openai"
modelName := "debug-step-coalesce-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -13730,16 +13753,15 @@ func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) {
providerName := "openai"
modelName := "debug-model-null-msg-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -13828,16 +13850,15 @@ func TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns(t *testi
providerName := "openai"
modelName := "debug-model-started-before-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -13940,16 +13961,15 @@ func TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns(t *testing.T)
providerName := "openai"
modelName := "debug-model-by-chat-started-before-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
@@ -14029,17 +14049,13 @@ func TestGetChatsFilter(t *testing.T) {
user := dbgen.User(t, store, database.User{})
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
provider := dbgen.AIProviderWithOptionalKey(t, store, database.AIProvider{
Type: database.AiProviderTypeOpenai,
}, "test-key")
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
Model: "test-model-" + uuid.NewString(),
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
@@ -14265,16 +14281,15 @@ func TestChatHasUnread(t *testing.T) {
user := dbgen.User(t, store, database.User{})
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
dbgen.ChatProvider(t, store, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model-" + uuid.NewString(),
DisplayName: "Test Model",
+68 -497
View File
@@ -276,6 +276,51 @@ func (q *sqlQuerier) GetAIProviderKeysByProviderID(ctx context.Context, provider
return items, nil
}
const getAIProviderKeysByProviderIDs = `-- name: GetAIProviderKeysByProviderIDs :many
SELECT
id, provider_id, api_key, api_key_key_id, created_at, updated_at
FROM
ai_provider_keys
WHERE
provider_id = ANY($1::uuid[])
ORDER BY
provider_id ASC,
created_at ASC,
id ASC
`
// Returns all keys for the requested providers, ordered by provider then created_at ASC
// so callers can select the oldest non-empty key per provider without issuing N queries.
func (q *sqlQuerier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error) {
rows, err := q.db.QueryContext(ctx, getAIProviderKeysByProviderIDs, pq.Array(providerIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []AIProviderKey
for rows.Next() {
var i AIProviderKey
if err := rows.Scan(
&i.ID,
&i.ProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertAIProviderKey = `-- name: InsertAIProviderKey :one
INSERT INTO ai_provider_keys (
id,
@@ -5020,6 +5065,23 @@ func (q *sqlQuerier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID
return err
}
const deleteChatModelConfigsByAIProviderID = `-- name: DeleteChatModelConfigsByAIProviderID :exec
UPDATE
chat_model_configs
SET
deleted = TRUE,
deleted_at = NOW(),
updated_at = NOW()
WHERE
ai_provider_id = $1::uuid
AND deleted = FALSE
`
func (q *sqlQuerier) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteChatModelConfigsByAIProviderID, aiProviderID)
return err
}
const deleteChatModelConfigsByProvider = `-- name: DeleteChatModelConfigsByProvider :exec
UPDATE
chat_model_configs
@@ -5164,18 +5226,14 @@ SELECT
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
FROM
chat_model_configs cmc
LEFT JOIN
JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.id = $1::uuid
AND cmc.deleted = FALSE
AND cmc.enabled = TRUE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
)
AND ap.enabled = TRUE
AND ap.deleted = FALSE
`
// Providers can be disabled independently of their model configs.
@@ -5209,17 +5267,13 @@ SELECT
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
FROM
chat_model_configs cmc
LEFT JOIN
JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.enabled = TRUE
AND cmc.deleted = FALSE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
)
AND ap.enabled = TRUE
AND ap.deleted = FALSE
ORDER BY
cmc.provider ASC,
cmc.model ASC,
@@ -5435,369 +5489,6 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo
return i, err
}
const deleteChatProviderByID = `-- name: DeleteChatProviderByID :exec
DELETE FROM
chat_providers
WHERE
id = $1::uuid
`
func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteChatProviderByID, id)
return err
}
const getChatProviderByID = `-- name: GetChatProviderByID :one
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
id = $1::uuid
`
func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) {
row := q.db.QueryRowContext(ctx, getChatProviderByID, id)
var i ChatProvider
err := row.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const getChatProviderByIDForUpdate = `-- name: GetChatProviderByIDForUpdate :one
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
id = $1::uuid
FOR UPDATE
`
func (q *sqlQuerier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error) {
row := q.db.QueryRowContext(ctx, getChatProviderByIDForUpdate, id)
var i ChatProvider
err := row.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
provider = $1::text
`
func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) {
row := q.db.QueryRowContext(ctx, getChatProviderByProvider, provider)
var i ChatProvider
err := row.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const getChatProviderByProviderForUpdate = `-- name: GetChatProviderByProviderForUpdate :one
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
provider = $1::text
FOR UPDATE
`
func (q *sqlQuerier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error) {
row := q.db.QueryRowContext(ctx, getChatProviderByProviderForUpdate, provider)
var i ChatProvider
err := row.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const getChatProviders = `-- name: GetChatProviders :many
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
ORDER BY
provider ASC
`
func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, error) {
rows, err := q.db.QueryContext(ctx, getChatProviders)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatProvider
for rows.Next() {
var i ChatProvider
if err := rows.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
enabled = TRUE
ORDER BY
provider ASC
`
func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) {
rows, err := q.db.QueryContext(ctx, getEnabledChatProviders)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatProvider
for rows.Next() {
var i ChatProvider
if err := rows.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertChatProvider = `-- name: InsertChatProvider :one
INSERT INTO chat_providers (
provider,
display_name,
api_key,
base_url,
api_key_key_id,
created_by,
enabled,
central_api_key_enabled,
allow_user_api_key,
allow_central_api_key_fallback
) VALUES (
$1::text,
$2::text,
$3::text,
$4::text,
$5::text,
$6::uuid,
$7::boolean,
$8::boolean,
$9::boolean,
$10::boolean
)
RETURNING
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
`
type InsertChatProviderParams struct {
Provider string `db:"provider" json:"provider"`
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
}
func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) {
row := q.db.QueryRowContext(ctx, insertChatProvider,
arg.Provider,
arg.DisplayName,
arg.APIKey,
arg.BaseUrl,
arg.ApiKeyKeyID,
arg.CreatedBy,
arg.Enabled,
arg.CentralApiKeyEnabled,
arg.AllowUserApiKey,
arg.AllowCentralApiKeyFallback,
)
var i ChatProvider
err := row.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const updateChatProvider = `-- name: UpdateChatProvider :one
UPDATE
chat_providers
SET
display_name = $1::text,
api_key = $2::text,
base_url = $3::text,
api_key_key_id = $4::text,
enabled = $5::boolean,
central_api_key_enabled = $6::boolean,
allow_user_api_key = $7::boolean,
allow_central_api_key_fallback = $8::boolean,
updated_at = NOW()
WHERE
id = $9::uuid
RETURNING
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
`
type UpdateChatProviderParams struct {
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
Enabled bool `db:"enabled" json:"enabled"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) {
row := q.db.QueryRowContext(ctx, updateChatProvider,
arg.DisplayName,
arg.APIKey,
arg.BaseUrl,
arg.ApiKeyKeyID,
arg.Enabled,
arg.CentralApiKeyEnabled,
arg.AllowUserApiKey,
arg.AllowCentralApiKeyFallback,
arg.ID,
)
var i ChatProvider
err := row.Scan(
&i.ID,
&i.Provider,
&i.DisplayName,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedBy,
&i.Enabled,
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const acquireChats = `-- name: AcquireChats :many
WITH acquired_chats AS (
UPDATE
@@ -27555,126 +27246,6 @@ func (q *sqlQuerier) UpdateUserSkillByUserIDAndName(ctx context.Context, arg Upd
return i, err
}
const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec
DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2
`
type DeleteUserChatProviderKeyParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
}
func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error {
_, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID)
return err
}
const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many
SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC
`
func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) {
rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UserChatProviderKey
for rows.Next() {
var i UserChatProviderKey
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.ChatProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one
UPDATE user_chat_provider_keys
SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW()
WHERE user_id = $3 AND chat_provider_id = $4
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
`
type UpdateUserChatProviderKeyParams struct {
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
}
func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) {
row := q.db.QueryRowContext(ctx, updateUserChatProviderKey,
arg.APIKey,
arg.ApiKeyKeyID,
arg.UserID,
arg.ChatProviderID,
)
var i UserChatProviderKey
err := row.Scan(
&i.ID,
&i.UserID,
&i.ChatProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
VALUES ($1, $2, $3, $4::text)
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
api_key = $3,
api_key_key_id = $4::text,
updated_at = NOW()
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
`
type UpsertUserChatProviderKeyParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
}
func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) {
row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey,
arg.UserID,
arg.ChatProviderID,
arg.APIKey,
arg.ApiKeyKeyID,
)
var i UserChatProviderKey
err := row.Scan(
&i.ID,
&i.UserID,
&i.ChatProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const allUserIDs = `-- name: AllUserIDs :many
SELECT DISTINCT id FROM USERS
WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END
@@ -32,6 +32,20 @@ WHERE
ORDER BY
provider_id ASC;
-- name: GetAIProviderKeysByProviderIDs :many
-- Returns all keys for the requested providers, ordered by provider then created_at ASC
-- so callers can select the oldest non-empty key per provider without issuing N queries.
SELECT
*
FROM
ai_provider_keys
WHERE
provider_id = ANY(@provider_ids::uuid[])
ORDER BY
provider_id ASC,
created_at ASC,
id ASC;
-- name: GetAIProviderKeys :many
-- Returns AI provider key rows. By default, only rows whose parent
-- provider is live (deleted = FALSE) are returned, so the API list
+17 -14
View File
@@ -34,17 +34,13 @@ SELECT
cmc.*
FROM
chat_model_configs cmc
LEFT JOIN
JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.enabled = TRUE
AND cmc.deleted = FALSE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
)
AND ap.enabled = TRUE
AND ap.deleted = FALSE
ORDER BY
cmc.provider ASC,
cmc.model ASC,
@@ -58,18 +54,14 @@ FROM
chat_model_configs cmc
-- Providers can be disabled independently of their model configs.
-- Check both to ensure the selected config is actually usable.
LEFT JOIN
JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.id = @id::uuid
AND cmc.deleted = FALSE
AND cmc.enabled = TRUE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
);
AND ap.enabled = TRUE
AND ap.deleted = FALSE;
-- name: InsertChatModelConfig :one
INSERT INTO chat_model_configs (
@@ -151,3 +143,14 @@ SET
WHERE
provider = @provider::text
AND deleted = FALSE;
-- name: DeleteChatModelConfigsByAIProviderID :exec
UPDATE
chat_model_configs
SET
deleted = TRUE,
deleted_at = NOW(),
updated_at = NOW()
WHERE
ai_provider_id = @ai_provider_id::uuid
AND deleted = FALSE;
-102
View File
@@ -1,102 +0,0 @@
-- name: GetChatProviderByID :one
SELECT
*
FROM
chat_providers
WHERE
id = @id::uuid;
-- name: GetChatProviderByIDForUpdate :one
SELECT
*
FROM
chat_providers
WHERE
id = @id::uuid
FOR UPDATE;
-- name: GetChatProviderByProvider :one
SELECT
*
FROM
chat_providers
WHERE
provider = @provider::text;
-- name: GetChatProviderByProviderForUpdate :one
SELECT
*
FROM
chat_providers
WHERE
provider = @provider::text
FOR UPDATE;
-- name: GetChatProviders :many
SELECT
*
FROM
chat_providers
ORDER BY
provider ASC;
-- name: GetEnabledChatProviders :many
SELECT
*
FROM
chat_providers
WHERE
enabled = TRUE
ORDER BY
provider ASC;
-- name: InsertChatProvider :one
INSERT INTO chat_providers (
provider,
display_name,
api_key,
base_url,
api_key_key_id,
created_by,
enabled,
central_api_key_enabled,
allow_user_api_key,
allow_central_api_key_fallback
) VALUES (
@provider::text,
@display_name::text,
@api_key::text,
@base_url::text,
sqlc.narg('api_key_key_id')::text,
sqlc.narg('created_by')::uuid,
@enabled::boolean,
@central_api_key_enabled::boolean,
@allow_user_api_key::boolean,
@allow_central_api_key_fallback::boolean
)
RETURNING
*;
-- name: UpdateChatProvider :one
UPDATE
chat_providers
SET
display_name = @display_name::text,
api_key = @api_key::text,
base_url = @base_url::text,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
enabled = @enabled::boolean,
central_api_key_enabled = @central_api_key_enabled::boolean,
allow_user_api_key = @allow_user_api_key::boolean,
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
-- name: DeleteChatProviderByID :exec
DELETE FROM
chat_providers
WHERE
id = @id::uuid;
@@ -1,20 +0,0 @@
-- name: GetUserChatProviderKeys :many
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
-- name: UpsertUserChatProviderKey :one
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
api_key = @api_key,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
updated_at = NOW()
RETURNING *;
-- name: UpdateUserChatProviderKey :one
UPDATE user_chat_provider_keys
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
RETURNING *;
-- name: DeleteUserChatProviderKey :exec
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
-4
View File
@@ -25,8 +25,6 @@ const (
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
@@ -99,8 +97,6 @@ const (
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
UniqueUserAiProviderKeysPkey UniqueConstraint = "user_ai_provider_keys_pkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id);
UniqueUserAiProviderKeysUserIDAiProviderIDKey UniqueConstraint = "user_ai_provider_keys_user_id_ai_provider_id_key" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id);
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
+183 -868
View File
File diff suppressed because it is too large Load Diff
+87 -530
View File
@@ -890,26 +890,6 @@ func TestPostChats_ClientType(t *testing.T) {
func TestListChats(t *testing.T) {
t.Parallel()
sortedChatIDs := func(ids []uuid.UUID) []uuid.UUID {
out := append([]uuid.UUID(nil), ids...)
slices.SortFunc(out, func(a, b uuid.UUID) int {
return strings.Compare(a.String(), b.String())
})
return out
}
requireRootIDs := func(t *testing.T, chats []codersdk.Chat, want ...uuid.UUID) []uuid.UUID {
t.Helper()
got := make([]uuid.UUID, 0, len(chats))
for _, chat := range chats {
require.Nil(t, chat.ParentChatID, "list should only return root chats")
got = append(got, chat.ID)
}
require.Equal(t, sortedChatIDs(want), sortedChatIDs(got))
return got
}
t.Run("Success", func(t *testing.T) {
t.Parallel()
@@ -1559,171 +1539,6 @@ func TestListChats(t *testing.T) {
require.Equal(t, archivedWithPR.ID, chats[0].ID)
})
})
t.Run("TitleSearch", func(t *testing.T) {
t.Parallel()
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Verify that the title: filter is wired through the endpoint.
// Exhaustive ILIKE behavior is tested in TestGetChatsFilter (Title/* subtests).
alpha := dbgen.Chat(t, db, database.Chat{
OrganizationID: firstUser.OrganizationID,
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "alpha project",
})
_ = dbgen.Chat(t, db, database.Chat{
OrganizationID: firstUser.OrganizationID,
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "beta unrelated",
})
ctx := testutil.Context(t, testutil.WaitLong)
t.Run("SingleWord", func(t *testing.T) {
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "title:alpha"})
require.NoError(t, err)
requireRootIDs(t, chats, alpha.ID)
})
t.Run("MultiWord", func(t *testing.T) {
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: `title:"alpha project"`})
require.NoError(t, err)
requireRootIDs(t, chats, alpha.ID)
})
t.Run("BareTermsRejected", func(t *testing.T) {
_, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "bare words"})
requireSDKError(t, err, http.StatusBadRequest)
})
})
t.Run("PRStatusFilter", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Verify that pr_status filter is wired through the endpoint.
// Exhaustive query logic is tested in TestGetChatsFilter (PRStatus/* subtests).
createChatWithPR := func(title, prURL, prState string, prDraft bool) database.Chat {
t.Helper()
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: title,
Status: database.ChatStatusCompleted,
})
refreshedAt := time.Now().UTC().Truncate(time.Second)
staleAt := refreshedAt.Add(time.Hour)
_, err := db.UpsertChatDiffStatusReference(
dbauthz.AsSystemRestricted(ctx),
database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
Url: sql.NullString{String: prURL, Valid: true},
GitBranch: "feature/test",
GitRemoteOrigin: "git@github.com:coder/coder.git",
StaleAt: staleAt,
},
)
require.NoError(t, err)
_, err = db.UpsertChatDiffStatus(
dbauthz.AsSystemRestricted(ctx),
database.UpsertChatDiffStatusParams{
ChatID: chat.ID,
Url: sql.NullString{String: prURL, Valid: true},
PullRequestState: sql.NullString{String: prState, Valid: true},
PullRequestDraft: prDraft,
RefreshedAt: refreshedAt,
StaleAt: staleAt,
},
)
require.NoError(t, err)
return chat
}
draftChat := createChatWithPR("draft pr", "https://github.com/coder/coder/pull/301", "open", true)
_ = createChatWithPR("open pr", "https://github.com/coder/coder/pull/302", "open", false)
t.Run("MatchesDraft", func(t *testing.T) {
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "pr_status:draft"})
require.NoError(t, err)
requireRootIDs(t, chats, draftChat.ID)
})
t.Run("InvalidPRStatus", func(t *testing.T) {
_, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "pr_status:bogus"})
requireSDKError(t, err, http.StatusBadRequest)
})
})
t.Run("UnreadFilter", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Verify that has_unread:true filter is wired through the endpoint.
// Exhaustive query logic is tested in TestGetChatsFilter (Unread/* subtests).
unreadChat := dbgen.Chat(t, db, database.Chat{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "unread chat",
Status: database.ChatStatusCompleted,
})
_, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
ChatID: unreadChat.ID,
CreatedBy: []uuid.UUID{user.UserID},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{`[{"type":"text","text":"hello"}]`},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
ProviderResponseID: []string{""},
})
require.NoError(t, err)
// Create a second chat with NO unread messages to prove filtering works.
_ = dbgen.Chat(t, db, database.Chat{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "read chat",
Status: database.ChatStatusCompleted,
})
t.Run("MatchesUnread", func(t *testing.T) {
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "has_unread:true"})
require.NoError(t, err)
requireRootIDs(t, chats, unreadChat.ID)
})
t.Run("InvalidHasUnread", func(t *testing.T) {
_, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "has_unread:bogus"})
requireSDKError(t, err, http.StatusBadRequest)
})
})
}
func TestListChatModels(t *testing.T) {
@@ -1802,18 +1617,12 @@ func TestListChatModels(t *testing.T) {
_ = coderdtest.CreateFirstUser(t, client.Client)
providerType := database.AiProviderTypeAnthropic
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: string(providerType),
CentralAPIKeyEnabled: ptr.Ref(false),
AllowUserAPIKey: ptr.Ref(true),
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, string(providerType), "")
provider := createAIProviderForTest(t, client, string(providerType), "")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(providerType),
AIProviderID: &aiProvider.ID,
AIProviderID: &provider.ID,
Model: "claude-sonnet",
ContextLimit: &contextLimit,
})
@@ -1833,7 +1642,7 @@ func TestListChatModels(t *testing.T) {
require.False(t, anthropicProvider.Available)
require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason)
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
_, err = client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{
APIKey: "user-api-key",
})
require.NoError(t, err)
@@ -1859,20 +1668,12 @@ func TestListChatModels(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "google",
APIKey: "central-api-key",
CentralAPIKeyEnabled: ptr.Ref(true),
AllowUserAPIKey: ptr.Ref(true),
AllowCentralAPIKeyFallback: ptr.Ref(true),
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "google", "provider-api-key")
provider := createAIProviderForTest(t, client, "google", "provider-api-key")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "google",
AIProviderID: &aiProvider.ID,
AIProviderID: &provider.ID,
Model: "gemini-1.5-pro",
ContextLimit: &contextLimit,
})
@@ -1891,7 +1692,7 @@ func TestListChatModels(t *testing.T) {
require.NotNil(t, googleProvider)
require.True(t, googleProvider.Available)
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
_, err = client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{
APIKey: "user-api-key",
})
require.NoError(t, err)
@@ -1919,17 +1720,12 @@ func TestListChatModels(t *testing.T) {
client := newChatClientWithDeploymentValues(t, values)
_ = coderdtest.CreateFirstUser(t, client.Client)
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "openai", "test-key")
provider := createAIProviderForTest(t, client, "openai", "test-key")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
AIProviderID: &provider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
@@ -1943,7 +1739,7 @@ func TestListChatModels(t *testing.T) {
require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model)
enabled := false
_, err = client.UpdateChatProvider(ctx, chatProvider.ID, codersdk.UpdateChatProviderConfigRequest{
_, err = client.UpdateAIProvider(ctx, provider.ID.String(), codersdk.UpdateAIProviderRequest{
Enabled: &enabled,
})
require.NoError(t, err)
@@ -2450,6 +2246,7 @@ func TestUserAIProviderKeys(t *testing.T) {
func TestListChatProviders(t *testing.T) {
t.Parallel()
t.Skip("legacy chat provider API removed in favor of AI provider API")
t.Run("Success", func(t *testing.T) {
t.Parallel()
@@ -2519,6 +2316,7 @@ func TestListChatProviders(t *testing.T) {
func TestCreateChatProvider(t *testing.T) {
t.Parallel()
t.Skip("legacy chat provider API removed in favor of AI provider API")
t.Run("Success", func(t *testing.T) {
t.Parallel()
@@ -2819,6 +2617,7 @@ func TestCreateChatProvider(t *testing.T) {
func TestUpdateChatProvider(t *testing.T) {
t.Parallel()
t.Skip("legacy chat provider API removed in favor of AI provider API")
t.Run("Success", func(t *testing.T) {
t.Parallel()
@@ -3111,282 +2910,7 @@ func TestUpdateChatProvider(t *testing.T) {
func TestDeleteChatProvider(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
err = client.DeleteChatProvider(ctx, provider.ID)
require.NoError(t, err)
providers, err := client.ListChatProviders(ctx)
require.NoError(t, err)
for _, listed := range providers {
require.NotEqual(t, provider.ID, listed.ID)
}
})
t.Run("SuccessWithHistoricalChats", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client.Client)
providerToDelete, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "delete-api-key",
AllowUserAPIKey: ptr.Ref(true),
})
require.NoError(t, err)
aiProviderToDelete := createAIProviderForTest(t, client, providerToDelete.Provider, "delete-api-key")
deleteContextLimit := int64(4096)
deleteIsDefault := true
configToDelete, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: providerToDelete.Provider,
AIProviderID: &aiProviderToDelete.ID,
Model: "gpt-4o-delete-provider",
ContextLimit: &deleteContextLimit,
IsDefault: &deleteIsDefault,
})
require.NoError(t, err)
keepProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "anthropic",
APIKey: "keep-api-key",
})
require.NoError(t, err)
keepAIProvider := createAIProviderForTest(t, client, keepProvider.Provider, "keep-api-key")
keepContextLimit := int64(8192)
keepConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: keepProvider.Provider,
AIProviderID: &keepAIProvider.ID,
Model: "claude-keep-provider",
ContextLimit: &keepContextLimit,
})
require.NoError(t, err)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
ModelConfigID: ptr.Ref(configToDelete.ID),
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "provider delete history " + t.Name(),
}},
})
require.NoError(t, err)
require.Equal(t, configToDelete.ID, chat.LastModelConfigID)
insertAssistantCostMessage(t, db, chat.ID, configToDelete.ID, 500)
_, err = client.UpsertUserChatProviderKey(ctx, providerToDelete.ID, codersdk.CreateUserChatProviderKeyRequest{
APIKey: "user-delete-key",
})
require.NoError(t, err)
userKeys, err := db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID)
require.NoError(t, err)
require.Len(t, userKeys, 1)
require.Equal(t, providerToDelete.ID, userKeys[0].ChatProviderID)
err = client.DeleteChatProvider(ctx, providerToDelete.ID)
require.NoError(t, err)
_, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), providerToDelete.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
providers, err := client.ListChatProviders(ctx)
require.NoError(t, err)
foundKeepProvider := false
for _, listed := range providers {
require.NotEqual(t, providerToDelete.ID, listed.ID)
if listed.ID == keepProvider.ID {
foundKeepProvider = true
}
}
require.True(t, foundKeepProvider)
configs, err := client.ListChatModelConfigs(ctx)
require.NoError(t, err)
foundDeletedConfig := false
foundKeepConfig := false
for _, config := range configs {
if config.ID == configToDelete.ID {
foundDeletedConfig = true
}
if config.ID == keepConfig.ID {
foundKeepConfig = true
require.True(t, config.IsDefault)
}
}
require.False(t, foundDeletedConfig)
require.True(t, foundKeepConfig)
defaultConfig, err := db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx))
require.NoError(t, err)
require.Equal(t, keepConfig.ID, defaultConfig.ID)
_, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), configToDelete.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
gotChat, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, chat.ID, gotChat.ID)
require.Equal(t, configToDelete.ID, gotChat.LastModelConfigID)
messages, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
foundHistoricalMessage := false
for _, message := range messages.Messages {
if message.ModelConfigID != nil && *message.ModelConfigID == configToDelete.ID {
foundHistoricalMessage = true
break
}
}
require.True(t, foundHistoricalMessage)
userKeys, err = db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID)
require.NoError(t, err)
require.Empty(t, userKeys)
})
t.Run("SuccessWithHistoricalChatsAndNoReplacementConfig", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "only-provider-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, provider.Provider, "only-provider-api-key")
contextLimit := int64(4096)
isDefault := true
config, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-only-provider",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
ModelConfigID: ptr.Ref(config.ID),
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "only provider delete history " + t.Name(),
}},
})
require.NoError(t, err)
require.Equal(t, config.ID, chat.LastModelConfigID)
insertAssistantCostMessage(t, db, chat.ID, config.ID, 250)
err = client.DeleteChatProvider(ctx, provider.ID)
require.NoError(t, err)
providers, err := client.ListChatProviders(ctx)
require.NoError(t, err)
for _, listed := range providers {
require.NotEqual(t, provider.ID, listed.ID)
}
_, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), provider.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
_, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), config.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
_, err = db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx))
require.ErrorIs(t, err, sql.ErrNoRows)
configs, err := client.ListChatModelConfigs(ctx)
require.NoError(t, err)
require.Empty(t, configs)
gotChat, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, config.ID, gotChat.LastModelConfigID)
messages, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
foundHistoricalMessage := false
for _, message := range messages.Messages {
if message.ModelConfigID != nil && *message.ModelConfigID == config.ID {
foundHistoricalMessage = true
break
}
}
require.True(t, foundHistoricalMessage)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
err := client.DeleteChatProvider(ctx, uuid.New())
requireSDKError(t, err, http.StatusNotFound)
})
t.Run("InvalidProviderID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
res, err := client.Request(
ctx,
http.MethodDelete,
"/api/experimental/chats/providers/not-a-uuid",
nil,
)
require.NoError(t, err)
defer res.Body.Close()
err = codersdk.ReadBodyAsError(res)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid chat provider ID.", sdkErr.Message)
})
t.Run("ForbiddenForOrganizationMember", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
err = memberClient.DeleteChatProvider(ctx, provider.ID)
requireSDKError(t, err, http.StatusForbidden)
})
t.Skip("legacy chat provider API removed in favor of AI provider API")
}
func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) {
@@ -3414,6 +2938,7 @@ func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) {
func TestUserChatProviderConfigs(t *testing.T) {
t.Parallel()
t.Skip("legacy chat provider API removed in favor of AI provider API")
requireUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) codersdk.UserChatProviderConfig {
t.Helper()
@@ -3809,6 +3334,7 @@ func TestUserChatProviderConfigs(t *testing.T) {
func TestUpsertUserChatProviderKey(t *testing.T) {
t.Parallel()
t.Skip("legacy chat provider API removed in favor of AI provider API")
t.Run("RejectsTooLargeAPIKey", func(t *testing.T) {
t.Parallel()
@@ -4263,6 +3789,26 @@ func TestUpdateChatModelConfig(t *testing.T) {
requireChatModelPricing(t, configs[0].ModelConfig, pricing)
})
t.Run("UnchangedProviderWithoutAIProviderID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
Provider: modelConfig.Provider,
Model: "gpt-4o-mini-updated",
})
require.NoError(t, err)
require.Equal(t, modelConfig.ID, updated.ID)
require.Equal(t, modelConfig.Provider, updated.Provider)
require.NotNil(t, updated.AIProviderID)
require.Equal(t, *modelConfig.AIProviderID, *updated.AIProviderID)
require.Equal(t, "gpt-4o-mini-updated", updated.Model)
})
t.Run("DisablePreservesRecordAndHidesItFromNonAdmins", func(t *testing.T) {
t.Parallel()
@@ -4482,11 +4028,12 @@ func TestUpdateChatModelConfig(t *testing.T) {
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
missingProviderID := uuid.New()
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
Provider: "anthropic",
AIProviderID: &missingProviderID,
})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "Chat provider is not configured.", sdkErr.Message)
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
})
t.Run("NotFoundWhenTargetRowDisappearsInTx", func(t *testing.T) {
@@ -10953,32 +10500,27 @@ func enableUserChatProviderKey(
adminClient *codersdk.ExperimentalClient,
userClient *codersdk.ExperimentalClient,
providerName string,
) codersdk.ChatProviderConfig {
) codersdk.AIProvider {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
providers, err := adminClient.ListChatProviders(ctx)
providers, err := adminClient.AIProviders(ctx)
require.NoError(t, err)
var provider codersdk.ChatProviderConfig
var provider codersdk.AIProvider
for _, candidate := range providers {
if candidate.Provider == providerName && candidate.Source == codersdk.ChatProviderConfigSourceDatabase {
if candidate.Type == codersdk.AIProviderType(providerName) {
provider = candidate
break
}
}
require.NotEqual(t, uuid.Nil, provider.ID)
updated, err := adminClient.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{
AllowUserAPIKey: ptr.Ref(true),
})
require.NoError(t, err)
_, err = userClient.UpsertUserChatProviderKey(ctx, updated.ID, codersdk.CreateUserChatProviderKeyRequest{
_, err = userClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{
APIKey: "test-user-api-key-" + uuid.NewString(),
})
require.NoError(t, err)
return updated
return provider
}
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
@@ -11792,13 +11334,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) {
defaultModelConfig := createChatModelConfig(t, adminClient)
provider := enableUserChatProviderKey(t, adminClient, memberClient, defaultModelConfig.Provider)
modelConfig := createAdditionalChatModelConfig(
t,
adminClient,
defaultModelConfig.Provider,
"gpt-4o-personal-"+uuid.NewString(),
)
err := adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{
modelProvider := createAIProviderForTest(t, adminClient, "anthropic", "")
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", modelProvider.ID, codersdk.CreateUserAIProviderKeyRequest{
APIKey: "test-user-api-key-" + uuid.NewString(),
})
require.NoError(t, err)
contextLimit := int64(4096)
modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "anthropic",
AIProviderID: &modelProvider.ID,
Model: "claude-personal-" + uuid.NewString(),
ContextLimit: &contextLimit,
})
require.NoError(t, err)
err = adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{
ModelConfigID: modelConfig.ID.String(),
})
require.NoError(t, err)
@@ -11813,19 +11362,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) {
defaultModelConfig.Provider,
"gpt-4o-personal-disabled-"+uuid.NewString(),
)
disabledProvider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "anthropic",
Enabled: ptr.Ref(false),
CentralAPIKeyEnabled: ptr.Ref(false),
AllowUserAPIKey: ptr.Ref(true),
disabledProvider := createAIProviderForTest(t, adminClient, "google", "test-api-key")
contextLimit = int64(4096)
disabledProviderModelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "google",
AIProviderID: &disabledProvider.ID,
Model: "gemini-personal-disabled-provider-" + uuid.NewString(),
ContextLimit: &contextLimit,
})
require.NoError(t, err)
enabled := false
disabledProvider, err = adminClient.UpdateAIProvider(ctx, disabledProvider.ID.String(), codersdk.UpdateAIProviderRequest{
Enabled: &enabled,
})
require.NoError(t, err)
disabledProviderModelConfig := createAdditionalChatModelConfig(
t,
adminClient,
"anthropic",
"claude-personal-disabled-provider-"+uuid.NewString(),
)
require.NotEqual(t, uuid.Nil, provider.ID)
require.NotEqual(t, uuid.Nil, disabledProvider.ID)
@@ -12153,12 +11703,19 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) {
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
defaultModel := createChatModelConfig(t, adminClient)
_ = enableUserChatProviderKey(t, adminClient, adminClient, defaultModel.Provider)
overrideModel := createAdditionalChatModelConfig(
t,
adminClient,
defaultModel.Provider,
"gpt-4o-root-personal-"+uuid.NewString(),
)
overrideProvider := createAIProviderForTest(t, adminClient, "anthropic", "")
_, err := adminClient.UpsertUserAIProviderKey(ctx, "me", overrideProvider.ID, codersdk.CreateUserAIProviderKeyRequest{
APIKey: "test-user-api-key-" + uuid.NewString(),
})
require.NoError(t, err)
contextLimit := int64(4096)
overrideModel, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "anthropic",
AIProviderID: &overrideProvider.ID,
Model: "claude-root-personal-" + uuid.NewString(),
ContextLimit: &contextLimit,
})
require.NoError(t, err)
disabledModel := createDisabledChatModelConfig(
t,
adminClient,
@@ -12203,7 +11760,7 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) {
require.NoError(t, err)
}
err := adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{
err = adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{
AllowUsers: true,
})
require.NoError(t, err)
@@ -2,6 +2,7 @@ package coderd
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
@@ -96,15 +97,19 @@ func insertAgentChatTestModelConfig(
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
provider := dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeOpenai,
Name: "test-openai",
DisplayName: sql.NullString{String: "OpenAI", Valid: true},
})
dbgen.AIProviderKey(t, db, database.AIProviderKey{
ProviderID: provider.ID,
APIKey: "test-api-key",
CreatedBy: createdBy,
})
return dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
CreatedBy: createdBy,
UpdatedBy: createdBy,
IsDefault: true,
+2 -2
View File
@@ -24,8 +24,8 @@ import (
// advisorOverrideStubStore stubs only the database methods that
// resolveAdvisorModelOverride exercises. The prod code calls
// GetEnabledChatModelConfigByID so the query joins chat_providers and
// filters both enabled flags atomically; tests simulate that by returning
// GetEnabledChatModelConfigByID so the query joins ai_providers and
// filters both enabled flags atomically. Tests simulate that by returning
// configs the stub treats as enabled.
type advisorOverrideStubStore struct {
database.Store
+66 -48
View File
@@ -350,13 +350,8 @@ func (p *Server) resolveAdvisorModelOverride(
return fallbackModel, fallbackCallConfig
}
// GetEnabledChatModelConfigByID checks the model config and referenced
// provider enabled state, so it returns sql.ErrNoRows the moment an
// admin disables either one. Using the cached ModelConfigByID here
// would keep resolving an override whose provider was just disabled,
// and an available fallback key would let ModelFromConfig succeed,
// silently routing advisor prompts to a provider the admin expects to
// be off.
// Re-read the override instead of using the cache so disabled models
// or providers stop routing advisor prompts immediately.
overrideConfig, err := p.db.GetEnabledChatModelConfigByID(
ctx,
advisorCfg.ModelConfigID,
@@ -7102,9 +7097,8 @@ func (p *Server) runChat(
// Fire title generation asynchronously so it doesn't block the
// chat response. It uses a detached context so it can finish
// even after the chat processing context is canceled.
// Snapshot model, provider keys, logger, and ctx before launch; all four get
// reassigned below (model = cuModel, providerKeys = computerUseProviderKeys,
// logger = logger.With(...), ctx = runCtx) and the goroutine captures by reference.
// Snapshot values captured by the goroutine because model, providerKeys,
// logger, and ctx are reassigned below.
titleModel := model
titleProviderKeys := providerKeys
titleLogger := logger
@@ -8502,13 +8496,17 @@ func (p *Server) resolveChatModel(
}
func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) {
if !provider.Enabled {
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
}
keys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID)
if err != nil {
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err)
}
return p.aiProviderConfigFromKeys(provider, keys)
}
func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) {
if !provider.Enabled {
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
}
apiKey := ""
// GetAIProviderKeysByProviderID orders keys oldest first. chatd consumes
// one provider-scoped key because runtime provider config has one API key slot.
@@ -8529,6 +8527,48 @@ func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvi
}, nil
}
func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) {
if len(providers) == 0 {
return nil, nil
}
providerIDs := make([]uuid.UUID, 0, len(providers))
for _, provider := range providers {
providerIDs = append(providerIDs, provider.ID)
}
keys, err := p.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
if err != nil {
return nil, xerrors.Errorf("get AI provider keys: %w", err)
}
keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers))
for _, key := range keys {
keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key)
}
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers))
for _, provider := range providers {
configuredProvider, err := p.aiProviderConfigFromKeys(provider, keysByProviderID[provider.ID])
if err != nil {
return nil, err
}
configuredProviders = append(configuredProviders, configuredProvider)
}
return configuredProviders, nil
}
func ensureUniqueConfiguredProviderTypes(providers []chatprovider.ConfiguredProvider) error {
seen := make(map[string]uuid.UUID, len(providers))
for _, provider := range providers {
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
if normalizedProvider == "" {
continue
}
if existingProviderID, ok := seen[normalizedProvider]; ok && existingProviderID != provider.ProviderID {
return xerrors.Errorf("multiple enabled AI providers use provider type %q; select an AI provider by ID", normalizedProvider)
}
seen[normalizedProvider] = provider.ProviderID
}
return nil
}
func (p *Server) resolveUserProviderAPIKeysForProvider(
ctx context.Context,
ownerID uuid.UUID,
@@ -8559,12 +8599,6 @@ func (p *Server) resolveUserProviderAPIKeysForProvider(
[]chatprovider.ConfiguredProvider{configuredProvider},
userKeys,
)
enabledProviders := map[string]struct{}{}
normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider)
if normalizedProvider != "" {
enabledProviders[normalizedProvider] = struct{}{}
}
chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders)
return keys, nil
}
@@ -8598,9 +8632,6 @@ func (p *Server) resolveUserProviderAPIKeys(
ownerID uuid.UUID,
selectedAIProviderID uuid.UUID,
) (chatprovider.ProviderAPIKeys, error) {
var configuredProviders []chatprovider.ConfiguredProvider
userKeys := []chatprovider.UserProviderKey{}
if selectedAIProviderID != uuid.Nil {
provider, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID)
if err != nil {
@@ -8608,48 +8639,35 @@ func (p *Server) resolveUserProviderAPIKeys(
}
return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
}
providers, err := p.configCache.EnabledProviders(ctx)
if err != nil {
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
"get enabled chat providers: %w",
"get enabled AI providers: %w",
err,
)
}
configuredProviders = make(
[]chatprovider.ConfiguredProvider, 0, len(providers),
)
for _, provider := range providers {
configuredProviders = append(
configuredProviders, chatprovider.ConfiguredProvider{
ProviderID: provider.ID,
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserAPIKey: provider.AllowUserApiKey,
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
},
)
configuredProviders, err := p.aiProviderConfigs(ctx, providers)
if err != nil {
return chatprovider.ProviderAPIKeys{}, err
}
allowAnyUserAPIKey := false
for _, provider := range configuredProviders {
if provider.AllowUserAPIKey {
allowAnyUserAPIKey = true
break
if err := ensureUniqueConfiguredProviderTypes(configuredProviders); err != nil {
return chatprovider.ProviderAPIKeys{}, err
}
}
if allowAnyUserAPIKey {
userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID)
userKeys := []chatprovider.UserProviderKey{}
if p.allowBYOK {
userKeyRows, err := p.db.GetUserAIProviderKeysByUserID(ctx, ownerID)
if err != nil {
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
"get user chat provider keys: %w",
"get user AI provider keys: %w",
err,
)
}
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
for _, userKey := range userKeyRows {
userKeys = append(userKeys, chatprovider.UserProviderKey{
ChatProviderID: userKey.ChatProviderID,
ChatProviderID: userKey.AIProviderID,
APIKey: userKey.APIKey,
})
}
+26 -17
View File
@@ -797,12 +797,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
providerID := uuid.New()
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Enabled: true,
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
@@ -961,12 +963,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
providerID := uuid.New()
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Enabled: true,
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
@@ -1110,11 +1114,13 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "anthropic",
CentralApiKeyEnabled: true,
AllowCentralApiKeyFallback: true,
providerID := uuid.New()
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
ID: providerID,
Type: database.AiProviderTypeAnthropic,
Enabled: true,
}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
require.NoError(t, err)
@@ -1185,10 +1191,13 @@ func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKe
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
providerID := uuid.New()
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Enabled: true,
}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
require.NoError(t, err)
@@ -3922,7 +3931,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
database.ChatModelConfig{}, xerrors.New("no model configured"),
).AnyTimes()
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
@@ -5696,7 +5705,7 @@ func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) {
return database.ChatModelConfig{}, chatloop.ErrInterrupted
},
).AnyTimes()
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
+15 -11
View File
@@ -6120,21 +6120,24 @@ func setOpenAIProviderBaseURL(
) {
t.Helper()
provider, err := db.GetChatProviderByProvider(ctx, "openai")
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
require.NoError(t, err)
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
for _, provider := range providers {
if provider.Type != database.AiProviderTypeOpenai {
continue
}
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
BaseUrl: baseURL,
Settings: provider.Settings,
SettingsKeyID: provider.SettingsKeyID,
})
require.NoError(t, err)
return
}
require.Fail(t, "openai provider not found")
}
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
@@ -7193,9 +7196,10 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) {
true,
false,
)
_, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
_, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
ID: uuid.New(),
UserID: user.ID,
ChatProviderID: provider.ID,
AIProviderID: provider.ID,
APIKey: userAPIKey,
})
require.NoError(t, err)
+7 -7
View File
@@ -30,7 +30,7 @@ const (
)
type cachedProviders struct {
providers []database.ChatProvider
providers []database.AIProvider
expiresAt time.Time
}
@@ -74,7 +74,7 @@ type chatConfigCache struct {
// Providers (singleton).
providers *cachedProviders
providerGeneration uint64
providerFetches singleflight.Group[string, []database.ChatProvider]
providerFetches singleflight.Group[string, []database.AIProvider]
// Model configs (keyed by ID).
modelTopologyEpoch uint64
@@ -131,7 +131,7 @@ func singleflightDoChan[K comparable, V any](
}
}
func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.ChatProvider, error) {
func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.AIProvider, error) {
if providers, ok := c.cachedProviders(); ok {
return providers, nil
}
@@ -141,12 +141,12 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat
ctx,
&c.providerFetches,
fmt.Sprintf("%d:providers", generation),
func() ([]database.ChatProvider, error) {
func() ([]database.AIProvider, error) {
if cached, ok := c.cachedProviders(); ok {
return cached, nil
}
fetched, err := c.db.GetEnabledChatProviders(c.ctx)
fetched, err := c.db.GetAIProviders(c.ctx, database.GetAIProvidersParams{})
if err != nil {
return nil, err
}
@@ -161,7 +161,7 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat
return slices.Clone(providers), nil
}
func (c *chatConfigCache) cachedProviders() ([]database.ChatProvider, bool) {
func (c *chatConfigCache) cachedProviders() ([]database.AIProvider, bool) {
c.mu.RLock()
entry := c.providers
c.mu.RUnlock()
@@ -188,7 +188,7 @@ func (c *chatConfigCache) providersGeneration() uint64 {
return generation
}
func (c *chatConfigCache) storeProviders(generation uint64, providers []database.ChatProvider) {
func (c *chatConfigCache) storeProviders(generation uint64, providers []database.AIProvider) {
c.mu.Lock()
defer c.mu.Unlock()
+30 -29
View File
@@ -22,7 +22,7 @@ import (
type stubChatConfigStore struct {
database.Store
getEnabledChatProviders func(context.Context) ([]database.ChatProvider, error)
getAIProviders func(context.Context) ([]database.AIProvider, error)
getChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
getDefaultChatModelConfig func(context.Context) (database.ChatModelConfig, error)
getUserChatCustomPrompt func(context.Context, uuid.UUID) (string, error)
@@ -35,12 +35,12 @@ type stubChatConfigStore struct {
advisorConfigCalls atomic.Int32
}
func (s *stubChatConfigStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
func (s *stubChatConfigStore) GetAIProviders(ctx context.Context, _ database.GetAIProvidersParams) ([]database.AIProvider, error) {
s.enabledProvidersCalls.Add(1)
if s.getEnabledChatProviders == nil {
panic("unexpected GetEnabledChatProviders call")
if s.getAIProviders == nil {
panic("unexpected GetAIProviders call")
}
return s.getEnabledChatProviders(ctx)
return s.getAIProviders(ctx)
}
func (s *stubChatConfigStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
@@ -80,9 +80,9 @@ func TestConfigCache_EnabledProviders_CacheHit(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
providers := []database.ChatProvider{testChatProvider("provider-a")}
providers := []database.AIProvider{testAIProvider("provider-a")}
store := &stubChatConfigStore{
getEnabledChatProviders: func(context.Context) ([]database.ChatProvider, error) {
getAIProviders: func(context.Context) ([]database.AIProvider, error) {
return providers, nil
},
}
@@ -104,9 +104,9 @@ func TestConfigCache_EnabledProviders_TTLExpiry(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
store := &stubChatConfigStore{}
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
call := store.enabledProvidersCalls.Load()
return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil
return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil
}
cache := newChatConfigCache(ctx, store, clock)
@@ -126,9 +126,9 @@ func TestConfigCache_EnabledProviders_Invalidation(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
store := &stubChatConfigStore{}
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
call := store.enabledProvidersCalls.Load()
return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil
return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil
}
cache := newChatConfigCache(ctx, store, clock)
@@ -398,12 +398,12 @@ func TestConfigCache_Singleflight(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitMedium)
clock := quartz.NewMock(t)
providers := []database.ChatProvider{testChatProvider("provider-a")}
providers := []database.AIProvider{testAIProvider("provider-a")}
fetchStarted := make(chan struct{})
releaseFetch := make(chan struct{})
var startedOnce sync.Once
store := &stubChatConfigStore{}
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
startedOnce.Do(func() { close(fetchStarted) })
<-releaseFetch
return providers, nil
@@ -411,7 +411,7 @@ func TestConfigCache_Singleflight(t *testing.T) {
cache := newChatConfigCache(ctx, store, clock)
const callers = 8
results := make([][]database.ChatProvider, callers)
results := make([][]database.AIProvider, callers)
errs := make([]error, callers)
var wg sync.WaitGroup
start := make(chan struct{})
@@ -441,13 +441,13 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitMedium)
clock := quartz.NewMock(t)
firstProviders := []database.ChatProvider{testChatProvider("provider-a")}
secondProviders := []database.ChatProvider{testChatProvider("provider-b")}
firstProviders := []database.AIProvider{testAIProvider("provider-a")}
secondProviders := []database.AIProvider{testAIProvider("provider-b")}
fetchStarted := make(chan struct{})
releaseFetch := make(chan struct{})
var startedOnce sync.Once
store := &stubChatConfigStore{}
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
call := store.enabledProvidersCalls.Load()
if call == 1 {
startedOnce.Do(func() { close(fetchStarted) })
@@ -458,7 +458,7 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) {
}
cache := newChatConfigCache(ctx, store, clock)
resultCh := make(chan []database.ChatProvider, 1)
resultCh := make(chan []database.AIProvider, 1)
errCh := make(chan error, 1)
go func() {
providers, err := cache.EnabledProviders(ctx)
@@ -494,14 +494,14 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing
ctx := testutil.Context(t, testutil.WaitMedium)
clock := quartz.NewMock(t)
staleProviders := []database.ChatProvider{testChatProvider("provider-stale")}
freshProviders := []database.ChatProvider{testChatProvider("provider-fresh")}
staleProviders := []database.AIProvider{testAIProvider("provider-stale")}
freshProviders := []database.AIProvider{testAIProvider("provider-fresh")}
firstStarted := make(chan struct{})
secondStarted := make(chan struct{})
releaseFirst := make(chan struct{})
releaseSecond := make(chan struct{})
store := &stubChatConfigStore{}
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
switch call := store.enabledProvidersCalls.Load(); call {
case 1:
close(firstStarted)
@@ -518,7 +518,7 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing
cache := newChatConfigCache(ctx, store, clock)
type result struct {
providers []database.ChatProvider
providers []database.AIProvider
err error
}
@@ -670,11 +670,12 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig(t *testi
require.Equal(t, int32(2), store.modelConfigByIDCalls.Load())
}
func testChatProvider(name string) database.ChatProvider {
return database.ChatProvider{
func testAIProvider(name string) database.AIProvider {
return database.AIProvider{
ID: uuid.New(),
Provider: name,
DisplayName: name,
Type: database.AIProviderType(name),
Name: name,
DisplayName: sql.NullString{String: name, Valid: true},
Enabled: true,
CreatedAt: time.Unix(0, 0).UTC(),
UpdatedAt: time.Unix(0, 0).UTC(),
@@ -737,19 +738,19 @@ func TestConfigCache_CallerCancellation(t *testing.T) {
name: "EnabledProviders",
setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) {
var once sync.Once
store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) {
once.Do(func() { close(started) })
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-release:
return []database.ChatProvider{testChatProvider("p")}, nil
return []database.AIProvider{testAIProvider("p")}, nil
}
}
},
setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) {
var once sync.Once
store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) {
store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) {
once.Do(func() { close(started) })
<-ctx.Done()
return nil, ctx.Err()
+2 -2
View File
@@ -154,12 +154,12 @@ func validateModelConfigAndResolveProvider(
}
func enabledProviderContainsName(
providers []database.ChatProvider,
providers []database.AIProvider,
providerName string,
) bool {
normalizedProviderName := chatprovider.NormalizeProvider(providerName)
for _, provider := range providers {
if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName {
if chatprovider.NormalizeProvider(string(provider.Type)) == normalizedProviderName {
return true
}
}
+35 -17
View File
@@ -2,6 +2,7 @@ package chatd
import (
"context"
"database/sql"
"encoding/json"
"sync"
"testing"
@@ -183,12 +184,14 @@ func seedInternalChatDeps(
UserID: user.ID,
OrganizationID: org.ID,
})
dbgen.ChatProvider(t, db, database.ChatProvider{
provider := dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
IsDefault: true,
})
@@ -309,24 +312,38 @@ func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) {
require.True(t, keys.HasProvider("bedrock"))
require.Empty(t, keys.APIKey("bedrock"))
})
t.Run("RejectsAmbiguousProviderTypeWithoutSelectedProvider", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
user, _, _ := seedInternalChatDeps(t, db)
insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "first-provider-api-key", true)
insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "second-provider-api-key", true)
keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, uuid.Nil)
require.ErrorContains(t, err, "multiple enabled AI providers use provider type")
require.Equal(t, chatprovider.ProviderAPIKeys{}, keys)
})
}
func TestResolveChatModel_AIProviderDisabled(t *testing.T) {
t.Parallel()
ctx := chatdTestContext(t)
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
db, ps := dbtestutil.NewDB(t)
user, org, _ := seedInternalChatDeps(t, db)
provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", false)
modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
Model: "gpt-4o-mini",
AIProviderID: uuid.NullUUID{
UUID: provider.ID,
Valid: true,
},
})
_, err := sqlDB.ExecContext(ctx, "UPDATE chat_model_configs SET ai_provider_id = $1 WHERE id = $2", provider.ID, modelConfig.ID)
require.NoError(t, err)
loadedModelConfig, err := db.GetChatModelConfigByID(ctx, modelConfig.ID)
require.NoError(t, err)
require.True(t, loadedModelConfig.AIProviderID.Valid)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
@@ -408,19 +425,20 @@ func insertInternalChatProvider(
centralAPIKeyEnabled bool,
allowUserAPIKey bool,
allowCentralAPIKeyFallback bool,
) database.ChatProvider {
) database.AIProvider {
t.Helper()
providerConfig := dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: provider,
DisplayName: provider,
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
}, func(p *database.InsertChatProviderParams) {
p.APIKey = apiKey
p.CentralApiKeyEnabled = centralAPIKeyEnabled
p.AllowUserApiKey = allowUserAPIKey
p.AllowCentralApiKeyFallback = allowCentralAPIKeyFallback
providerConfig := dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AIProviderType(provider),
Name: "test-" + uuid.NewString(),
DisplayName: sql.NullString{String: provider, Valid: true},
})
if apiKey != "" {
dbgen.AIProviderKey(t, db, database.AIProviderKey{
ProviderID: providerConfig.ID,
APIKey: apiKey,
})
}
return providerConfig
}
+6 -3
View File
@@ -348,7 +348,8 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{uuid.Nil}).Return(nil, nil)
generated := &generatedChatTitle{}
server := titleOverrideTestServer(db, logger)
@@ -498,7 +499,8 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T)
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
server := titleOverrideTestServer(db, logger)
model, gotConfig, _, err := server.resolveManualTitleModel(
@@ -524,7 +526,8 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil)
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
server := titleOverrideTestServer(db, logger)
model, gotConfig, _, err := server.resolveManualTitleModel(
+4 -6
View File
@@ -33,16 +33,15 @@ func TestUpdateLastTurnSummaryRejectsStaleWrites(t *testing.T) {
OrganizationID: org.ID,
})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
provider := dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
@@ -102,16 +101,15 @@ func TestPendingChatPersistsSummaryButSkipsWebPush(t *testing.T) {
OrganizationID: org.ID,
})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
provider := dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
+1 -1
View File
@@ -944,7 +944,7 @@ func TestWorker(t *testing.T) {
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
// 3. Set up FK chain: chat_providers -> chat_model_configs -> chats.
// 3. Set up FK chain: ai_providers -> chat_model_configs -> chats.
_ = dbgen.ChatProvider(t, db, database.ChatProvider{})
modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
+20 -11
View File
@@ -122,11 +122,17 @@ func seedChatDependencies(
UserID: user.ID,
OrganizationID: org.ID,
})
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
provider := dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeOpenai,
Name: "test-" + uuid.NewString(),
BaseUrl: safetyNet.URL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
})
dbgen.AIProviderKey(t, db, database.AIProviderKey{
ProviderID: provider.ID,
})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
IsDefault: true,
@@ -186,21 +192,24 @@ func setOpenAIProviderBaseURL(
) {
t.Helper()
provider, err := db.GetChatProviderByProvider(ctx, "openai")
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
require.NoError(t, err)
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
for _, provider := range providers {
if provider.Type != database.AiProviderTypeOpenai {
continue
}
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
CentralApiKeyEnabled: true,
AllowUserApiKey: false,
AllowCentralApiKeyFallback: false,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
BaseUrl: baseURL,
Settings: provider.Settings,
SettingsKeyID: provider.SettingsKeyID,
})
require.NoError(t, err)
return
}
require.Fail(t, "openai provider not found")
}
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
+2 -105
View File
@@ -74,29 +74,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
}
}
userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid)
if err != nil {
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
}
for _, userProviderKey := range userProviderKeys {
if strings.TrimSpace(userProviderKey.APIKey) == "" {
continue
}
if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
UserID: userProviderKey.UserID,
ChatProviderID: userProviderKey.ChatProviderID,
APIKey: userProviderKey.APIKey,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
}); err != nil {
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
}
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid)
if err != nil {
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
@@ -134,35 +111,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
providers, err := cryptDB.GetChatProviders(ctx)
if err != nil {
return xerrors.Errorf("get chat providers: %w", err)
}
log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers)))
for idx, provider := range providers {
if strings.TrimSpace(provider.APIKey) == "" {
continue
}
if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
if err != nil {
return xerrors.Errorf("get ai providers: %w", err)
@@ -313,26 +261,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
}
}
userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid)
if err != nil {
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
}
for _, userProviderKey := range userProviderKeys {
if !userProviderKey.ApiKeyKeyID.Valid {
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
UserID: userProviderKey.UserID,
ChatProviderID: userProviderKey.ChatProviderID,
APIKey: userProviderKey.APIKey,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
}); err != nil {
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
}
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
}
userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid)
if err != nil {
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
@@ -370,31 +298,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
providers, err := cryptDB.GetChatProviders(ctx)
if err != nil {
return xerrors.Errorf("get chat providers: %w", err)
}
log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers)))
for idx, provider := range providers {
if !provider.ApiKeyKeyID.Valid {
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
if err != nil {
return xerrors.Errorf("get ai providers: %w", err)
@@ -475,16 +378,10 @@ DELETE FROM user_links
DELETE FROM external_auth_links
WHERE oauth_access_token_key_id IS NOT NULL
OR oauth_refresh_token_key_id IS NOT NULL;
DELETE FROM user_chat_provider_keys
WHERE api_key_key_id IS NOT NULL;
DELETE FROM user_ai_provider_keys
WHERE api_key_key_id IS NOT NULL;
DELETE FROM user_secrets
WHERE value_key_id IS NOT NULL;
UPDATE chat_providers
SET api_key = '',
api_key_key_id = NULL
WHERE api_key_key_id IS NOT NULL;
UPDATE ai_providers
SET settings = NULL,
settings_key_id = NULL
@@ -502,9 +399,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
store := database.New(sqlDB)
_, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens)
if err != nil {
return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err)
return xerrors.Errorf("delete encrypted tokens and AI provider keys: %w", err)
}
log.Info(ctx, "deleted encrypted user tokens and chat provider API keys")
log.Info(ctx, "deleted encrypted user tokens and AI provider API keys")
log.Info(ctx, "revoking all active keys")
keys, err := store.GetDBCryptKeys(ctx)
+13 -137
View File
@@ -521,6 +521,19 @@ func (db *dbCrypt) GetAIProviderKeysByProviderID(ctx context.Context, providerID
return keys, nil
}
func (db *dbCrypt) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) {
keys, err := db.Store.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptAIProviderKey(&keys[i]); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) InsertAIProviderKey(ctx context.Context, params database.InsertAIProviderKeyParams) (database.AIProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
@@ -576,92 +589,6 @@ func (db *dbCrypt) UpdateEncryptedAIProviderKey(ctx context.Context, params data
return key, nil
}
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByID(ctx, id)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetEnabledChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.InsertChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.UpdateChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error {
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
}
@@ -754,57 +681,6 @@ func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params
return key, nil
}
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
}
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.UserChatProviderKey{}, err
}
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := db.decryptUserChatProviderKey(&key); err != nil {
return database.UserChatProviderKey{}, err
}
return key, nil
}
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.UserChatProviderKey{}, err
}
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := db.decryptUserChatProviderKey(&key); err != nil {
return database.UserChatProviderKey{}, err
}
return key, nil
}
// decryptMCPServerConfig decrypts all encrypted fields on a
// single MCPServerConfig in place.
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
+12 -107
View File
@@ -1281,6 +1281,18 @@ func TestAIProviderKeys(t *testing.T) {
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
})
t.Run("GetAIProviderKeysByProviderIDs", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
keys, err := crypt.GetAIProviderKeysByProviderIDs(ctx, []uuid.UUID{provider.ID})
require.NoError(t, err)
require.Len(t, keys, 1)
requireAIProviderKeyDecrypted(t, keys[0], ciphers, apiKey)
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
})
t.Run("DeleteAIProviderKey", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
@@ -1558,113 +1570,6 @@ func TestMCPServerUserTokens(t *testing.T) {
})
}
func TestUserChatProviderKeys(t *testing.T) {
t.Parallel()
ctx := context.Background()
const (
//nolint:gosec // test credentials
initialAPIKey = "sk-initial-api-key-value"
//nolint:gosec // test credentials
updatedAPIKey = "sk-updated-api-key-value"
)
insertProviderAndKey := func(
t *testing.T,
crypt *dbCrypt,
ciphers []Cipher,
) (database.ChatProvider, database.UserChatProviderKey) {
t.Helper()
user := dbgen.User(t, crypt, database.User{})
provider := dbgen.ChatProvider(t, crypt, database.ChatProvider{
AllowUserApiKey: true,
}, func(params *database.InsertChatProviderParams) {
params.APIKey = ""
})
key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: user.ID,
ChatProviderID: provider.ID,
APIKey: initialAPIKey,
})
require.NoError(t, err)
require.Equal(t, initialAPIKey, key.APIKey)
require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String)
return provider, key
}
getUserChatProviderKey := func(t *testing.T, store interface {
GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error)
}, userID uuid.UUID, providerID uuid.UUID,
) database.UserChatProviderKey {
t.Helper()
keys, err := store.GetUserChatProviderKeys(ctx, userID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, providerID, keys[0].ChatProviderID)
return keys[0]
}
t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
require.Equal(t, key.ID, got.ID)
require.Equal(t, initialAPIKey, got.APIKey)
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
require.NotEqual(t, initialAPIKey, rawKey.APIKey)
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey)
})
t.Run("GetUserChatProviderKeys", func(t *testing.T) {
t.Parallel()
_, crypt, ciphers := setup(t)
_, key := insertProviderAndKey(t, crypt, ciphers)
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, key.ID, keys[0].ID)
require.Equal(t, initialAPIKey, keys[0].APIKey)
require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String)
})
t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: key.UserID,
ChatProviderID: provider.ID,
APIKey: updatedAPIKey,
})
require.NoError(t, err)
require.Equal(t, key.ID, updated.ID)
require.Equal(t, key.CreatedAt, updated.CreatedAt)
require.False(t, updated.UpdatedAt.Before(key.UpdatedAt))
require.Equal(t, updatedAPIKey, updated.APIKey)
require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String)
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
require.Equal(t, updatedAPIKey, got.APIKey)
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, updatedAPIKey, keys[0].APIKey)
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
require.NotEqual(t, updatedAPIKey, rawKey.APIKey)
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
})
}
func TestUserSecrets(t *testing.T) {
t.Parallel()
ctx := context.Background()
+5
View File
@@ -60,10 +60,15 @@ else
echo "Base ref $base not found locally, fetching $ref..."
git fetch origin "$ref" --depth=1 2>/dev/null || true
if ! git rev-parse --verify "$base" >/dev/null 2>&1; then
if git rev-parse --verify origin/main >/dev/null 2>&1; then
echo "WARNING: could not fetch base ref $base, falling back to origin/main merge base."
base=$(git merge-base HEAD origin/main 2>/dev/null || echo "origin/main")
else
echo "ERROR: could not fetch base ref $base."
exit 1
fi
fi
fi
found=0
if ! diff_output=$(git diff "$base" -U0 -- . "${exclude_pathspecs[@]}" 2>&1); then
-9
View File
@@ -284,11 +284,6 @@ describe("api.ts", () => {
providers: [],
},
],
[
"/api/experimental/chats/providers",
() => API.experimental.getChatProviderConfigs(),
[],
],
[
"/api/experimental/chats/model-configs",
() => API.experimental.getChatModelConfigs(),
@@ -310,10 +305,6 @@ describe("api.ts", () => {
"/api/experimental/chats/models",
() => API.experimental.getChatModels(),
],
[
"/api/experimental/chats/providers",
() => API.experimental.getChatProviderConfigs(),
],
[
"/api/experimental/chats/model-configs",
() => API.experimental.getChatModelConfigs(),
-67
View File
@@ -407,11 +407,8 @@ export type DeploymentConfig = Readonly<{
options: TypesGen.SerpentOption[];
}>;
const chatProviderConfigsPath = "/api/experimental/chats/providers";
const aiProviderConfigsPath = "/api/v2/ai/providers";
const chatModelConfigsPath = "/api/experimental/chats/model-configs";
const userChatProviderConfigsPath =
"/api/experimental/chats/user-provider-configs";
const userSkillsPath = (user: string) =>
`/api/experimental/users/${encodeURIComponent(user)}/skills`;
const userSkillPath = (user: string, name: string) =>
@@ -3731,42 +3728,6 @@ class ExperimentalApiMethods {
);
};
getChatProviderConfigs = async (): Promise<TypesGen.ChatProviderConfig[]> => {
const response = await this.axios.get<TypesGen.ChatProviderConfig[]>(
chatProviderConfigsPath,
);
return response.data;
};
createChatProviderConfig = async (
req: TypesGen.CreateChatProviderConfigRequest,
): Promise<TypesGen.ChatProviderConfig> => {
const response = await this.axios.post<TypesGen.ChatProviderConfig>(
chatProviderConfigsPath,
req,
);
return response.data;
};
updateChatProviderConfig = async (
providerConfigId: string,
req: TypesGen.UpdateChatProviderConfigRequest,
): Promise<TypesGen.ChatProviderConfig> => {
const response = await this.axios.patch<TypesGen.ChatProviderConfig>(
`${chatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
req,
);
return response.data;
};
deleteChatProviderConfig = async (
providerConfigId: string,
): Promise<void> => {
await this.axios.delete(
`${chatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
);
};
getChatModelConfigs = async (): Promise<TypesGen.ChatModelConfig[]> => {
const response =
await this.axios.get<TypesGen.ChatModelConfig[]>(chatModelConfigsPath);
@@ -3800,34 +3761,6 @@ class ExperimentalApiMethods {
);
};
getUserChatProviderConfigs = async (): Promise<
TypesGen.UserChatProviderConfig[]
> => {
const response = await this.axios.get<TypesGen.UserChatProviderConfig[]>(
userChatProviderConfigsPath,
);
return response.data;
};
upsertUserChatProviderKey = async (
providerConfigId: string,
req: TypesGen.CreateUserChatProviderKeyRequest,
): Promise<TypesGen.UserChatProviderConfig> => {
const response = await this.axios.put<TypesGen.UserChatProviderConfig>(
`${userChatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
req,
);
return response.data;
};
deleteUserChatProviderKey = async (
providerConfigId: string,
): Promise<void> => {
await this.axios.delete(
`${userChatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
);
};
getMCPServerConfigs = async (): Promise<TypesGen.MCPServerConfig[]> => {
const response =
await this.axios.get<TypesGen.MCPServerConfig[]>(mcpServerConfigsPath);
@@ -39,12 +39,36 @@ const createProviderConfig = (
updated_at: overrides.updated_at ?? now,
});
const createProviderKey = (providerId: string): TypesGen.AIProviderKey => ({
id: `key-${providerId}`,
masked: "sk-...test",
created_at: now,
});
const toAIProvider = (
providerConfig: TypesGen.ChatProviderConfig,
): TypesGen.AIProvider => ({
id: providerConfig.id,
type: providerConfig.provider as TypesGen.AIProviderType,
name: providerConfig.provider,
display_name: providerConfig.display_name,
enabled: providerConfig.enabled,
base_url: providerConfig.base_url ?? "",
api_keys: providerConfig.has_api_key
? [createProviderKey(providerConfig.id)]
: [],
settings: {},
created_at: providerConfig.created_at ?? now,
updated_at: providerConfig.updated_at ?? now,
});
const createModelConfig = (
overrides: Partial<TypesGen.ChatModelConfig> &
Pick<TypesGen.ChatModelConfig, "id" | "provider" | "model">,
): TypesGen.ChatModelConfig => ({
id: overrides.id,
provider: overrides.provider,
ai_provider_id: overrides.ai_provider_id,
model: overrides.model,
display_name: overrides.display_name ?? overrides.model,
enabled: overrides.enabled ?? true,
@@ -68,11 +92,9 @@ const setupChatSpies = (state: {
modelConfigs: TypesGen.ChatModelConfig[];
modelCatalog: TypesGen.ChatModelsResponse;
}) => {
spyOn(API.experimental, "getChatProviderConfigs").mockImplementation(
async () => {
return state.providerConfigs;
},
);
spyOn(API.experimental, "listAIProviders").mockImplementation(async () => {
return state.providerConfigs.map(toAIProvider);
});
spyOn(API.experimental, "getChatModelConfigs").mockImplementation(
async () => {
return state.modelConfigs;
@@ -82,31 +104,26 @@ const setupChatSpies = (state: {
return state.modelCatalog;
});
spyOn(API.experimental, "createChatProviderConfig").mockImplementation(
spyOn(API.experimental, "createAIProvider").mockImplementation(
async (req) => {
const created = createProviderConfig({
id: `provider-${Date.now()}`,
provider: req.provider ?? "",
provider: req.type ?? "openai",
display_name: req.display_name ?? "",
has_api_key: (req.api_key ?? "").trim().length > 0,
central_api_key_enabled: req.central_api_key_enabled ?? true,
allow_user_api_key: req.allow_user_api_key ?? false,
allow_central_api_key_fallback:
req.allow_central_api_key_fallback ?? false,
has_api_key:
req.api_keys?.some((apiKey) => apiKey.trim().length > 0) ?? false,
base_url: req.base_url ?? "",
enabled: req.enabled ?? true,
source: "database",
});
state.providerConfigs = [
...state.providerConfigs.filter(
(p) => !(p.id === nilProviderConfigID && p.provider === req.provider),
),
...state.providerConfigs.filter((p) => p.id !== created.id),
created,
];
return created;
return toAIProvider(created);
},
);
spyOn(API.experimental, "updateChatProviderConfig").mockImplementation(
spyOn(API.experimental, "updateAIProvider").mockImplementation(
async (providerConfigId, req) => {
const idx = state.providerConfigs.findIndex(
(p) => p.id === providerConfigId,
@@ -122,29 +139,30 @@ const setupChatSpies = (state: {
? req.display_name
: current.display_name,
has_api_key:
typeof req.api_key === "string"
? req.api_key.trim().length > 0
: current.has_api_key,
central_api_key_enabled:
typeof req.central_api_key_enabled === "boolean"
? req.central_api_key_enabled
: current.central_api_key_enabled,
allow_user_api_key:
typeof req.allow_user_api_key === "boolean"
? req.allow_user_api_key
: current.allow_user_api_key,
allow_central_api_key_fallback:
typeof req.allow_central_api_key_fallback === "boolean"
? req.allow_central_api_key_fallback
: current.allow_central_api_key_fallback,
req.api_keys === undefined
? current.has_api_key
: req.api_keys.some((apiKey) =>
apiKey.api_key !== undefined
? apiKey.api_key.trim().length > 0
: apiKey.id !== undefined,
),
base_url:
typeof req.base_url === "string" ? req.base_url : current.base_url,
enabled:
typeof req.enabled === "boolean" ? req.enabled : current.enabled,
updated_at: now,
};
state.providerConfigs = state.providerConfigs.map((p, i) =>
i === idx ? updated : p,
);
return updated;
return toAIProvider(updated);
},
);
spyOn(API.experimental, "deleteAIProvider").mockImplementation(
async (providerConfigId) => {
state.providerConfigs = state.providerConfigs.filter(
(p) => p.id !== providerConfigId,
);
},
);
@@ -154,6 +172,7 @@ const setupChatSpies = (state: {
id: `model-${state.modelConfigs.length + 1}`,
provider: req.provider ?? "",
model: req.model,
ai_provider_id: req.ai_provider_id,
display_name: req.display_name || req.model,
enabled: req.enabled ?? true,
context_limit:
@@ -181,10 +200,6 @@ const setupChatSpies = (state: {
},
);
// Unused but mock to avoid errors.
spyOn(API.experimental, "deleteChatProviderConfig").mockResolvedValue(
undefined,
);
spyOn(API.experimental, "updateChatModelConfig").mockImplementation(
async (modelConfigId, req) => {
const idx = state.modelConfigs.findIndex((m) => m.id === modelConfigId);