mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: remove legacy chat provider tables (#25416)
This commit is contained in:
@@ -47,10 +47,9 @@ func FakeOpenAICompatProviderAPIKeys(t testing.TB) chatprovider.ProviderAPIKeys
|
||||
}
|
||||
|
||||
// CreateOpenAICompatChatModelConfig creates the default provider and model
|
||||
// config used by chat runtime tests. Tests that create chats should also set
|
||||
// Options.ChatProviderAPIKeys, usually via FakeOpenAICompatProviderAPIKeys, so
|
||||
// background chat work routes to a local provider until coderd closes. baseURL,
|
||||
// when non-empty, is stored on the provider config.
|
||||
// config used by chat runtime tests. Tests can pass a baseURL to route chat work
|
||||
// to a specific local provider. If baseURL is empty, this helper starts a fake
|
||||
// OpenAI-compatible provider.
|
||||
func CreateOpenAICompatChatModelConfig(
|
||||
t testing.TB,
|
||||
client *codersdk.ExperimentalClient,
|
||||
@@ -58,26 +57,19 @@ func CreateOpenAICompatChatModelConfig(
|
||||
) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: TestChatProviderOpenAICompat,
|
||||
APIKey: TestChatProviderAPIKey,
|
||||
BaseURL: baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProviderBaseURL := baseURL
|
||||
if aiProviderBaseURL == "" {
|
||||
aiProviderBaseURL = "https://api.example.com/v1"
|
||||
if baseURL == "" {
|
||||
baseURL = chattest.OpenAI(t)
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderType(TestChatProviderOpenAICompat),
|
||||
Name: "test-" + uuid.NewString(),
|
||||
BaseURL: aiProviderBaseURL,
|
||||
BaseURL: baseURL,
|
||||
Enabled: true,
|
||||
APIKeys: []string{TestChatProviderAPIKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
|
||||
@@ -12,10 +12,9 @@ const (
|
||||
CheckAiModelPricesOutputPriceCheck CheckConstraint = "ai_model_prices_output_price_check" // ai_model_prices
|
||||
CheckAiProvidersNameCheck CheckConstraint = "ai_providers_name_check" // ai_providers
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsAiProviderRequiredWhenActive CheckConstraint = "chat_model_configs_ai_provider_required_when_active" // chat_model_configs
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
@@ -45,7 +44,6 @@ const (
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys
|
||||
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
|
||||
CheckUserSkillsContentSize CheckConstraint = "user_skills_content_size" // user_skills
|
||||
CheckUserSkillsDescriptionSize CheckConstraint = "user_skills_description_size" // user_skills
|
||||
CheckUserSkillsNameFormat CheckConstraint = "user_skills_name_format" // user_skills
|
||||
|
||||
@@ -1967,6 +1967,13 @@ func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) e
|
||||
return q.db.DeleteChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAIProvider); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -1974,13 +1981,6 @@ func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider
|
||||
return q.db.DeleteChatModelConfigsByProvider(ctx, provider)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
@@ -2313,17 +2313,6 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
|
||||
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
@@ -2617,6 +2606,13 @@ func (q *querier) GetAIProviderKeysByProviderID(ctx context.Context, providerID
|
||||
return q.db.GetAIProviderKeysByProviderID(ctx, providerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
|
||||
return nil, err
|
||||
@@ -3128,41 +3124,6 @@ func (q *querier) GetChatPlanModeInstructions(ctx context.Context) (string, erro
|
||||
return q.db.GetChatPlanModeInstructions(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByIDForUpdate(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByProvider(ctx, provider)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.GetChatProviderByProviderForUpdate(ctx, provider)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
@@ -3403,13 +3364,6 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch
|
||||
return q.db.GetEnabledChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
@@ -4661,17 +4615,6 @@ func (q *querier) GetUserChatPersonalModelOverride(ctx context.Context, arg data
|
||||
return q.db.GetUserChatPersonalModelOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetUserChatProviderKeys(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
@@ -5525,13 +5468,6 @@ func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.Insert
|
||||
return q.db.InsertChatModelConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.InsertChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
@@ -6749,13 +6685,6 @@ func (q *querier) UpdateChatPlanModeByID(ctx context.Context, arg database.Updat
|
||||
return q.db.UpdateChatPlanModeByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return q.db.UpdateChatProvider(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
// UpdateChatStatus is used by the chat processor to change chat status.
|
||||
// It should be called with system context.
|
||||
@@ -7429,17 +7358,6 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
|
||||
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpdateUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) {
|
||||
user, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
@@ -8371,17 +8289,6 @@ func (q *querier) UpsertUserChatPersonalModelOverride(ctx context.Context, arg d
|
||||
return q.db.UpsertUserChatPersonalModelOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpsertUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -557,10 +557,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().DeleteChatModelConfigsByProvider(gomock.Any(), providerName).Return(nil).AnyTimes()
|
||||
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
s.Run("DeleteChatModelConfigsByAIProviderID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
providerID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatModelConfigsByAIProviderID(gomock.Any(), providerID).Return(nil).AnyTimes()
|
||||
check.Args(providerID).Asserts(rbac.ResourceAIProvider, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -972,34 +972,7 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviderByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviderByIDForUpdate(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerName := "test-provider"
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
|
||||
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
|
||||
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviderByProviderForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerName := "test-provider"
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
|
||||
dbm.EXPECT().GetChatProviderByProviderForUpdate(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
|
||||
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
|
||||
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes()
|
||||
@@ -1112,12 +1085,7 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
|
||||
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
threshold := dbtime.Now()
|
||||
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
|
||||
@@ -1166,17 +1134,7 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertChatProviderParams{
|
||||
Provider: "test-provider",
|
||||
DisplayName: "Test Provider",
|
||||
APIKey: "test-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled})
|
||||
dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
|
||||
s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{})
|
||||
@@ -1303,17 +1261,7 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.ChatProvider{})
|
||||
arg := database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: "Updated Provider",
|
||||
APIKey: "updated-api-key",
|
||||
Enabled: true,
|
||||
}
|
||||
dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider)
|
||||
}))
|
||||
|
||||
s.Run("UpdateChatPinOrder", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatPinOrderParams{
|
||||
@@ -2933,36 +2881,7 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
|
||||
}))
|
||||
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
|
||||
}))
|
||||
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
|
||||
}))
|
||||
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
|
||||
s.Run("GetUserAIProviderKeyByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.GetUserAIProviderKeyByProviderIDParams{UserID: u.ID, AIProviderID: uuid.New()}
|
||||
@@ -6582,6 +6501,15 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
dbm.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), provider.ID).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB})
|
||||
}))
|
||||
s.Run("GetAIProviderKeysByProviderIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
providerIDs := []uuid.UUID{providerA.ID, providerB.ID}
|
||||
keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerA.ID})
|
||||
keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerB.ID})
|
||||
dbm.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), providerIDs).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes()
|
||||
check.Args(providerIDs).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB})
|
||||
}))
|
||||
s.Run("GetAIProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{})
|
||||
keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{})
|
||||
|
||||
@@ -149,8 +149,29 @@ const (
|
||||
|
||||
func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelConfig, munge ...func(*database.InsertChatModelConfigParams)) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
providerName := takeFirst(seed.Provider, "openai")
|
||||
aiProviderID := seed.AIProviderID
|
||||
if !aiProviderID.Valid {
|
||||
providers, err := db.GetAIProviders(genCtx, database.GetAIProvidersParams{IncludeDisabled: true})
|
||||
require.NoError(t, err, "get ai providers")
|
||||
var provider database.AIProvider
|
||||
for _, candidate := range providers {
|
||||
if candidate.Type != database.AIProviderType(providerName) {
|
||||
continue
|
||||
}
|
||||
if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) {
|
||||
provider = candidate
|
||||
}
|
||||
}
|
||||
if provider.ID == uuid.Nil {
|
||||
provider = AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AIProviderType(providerName),
|
||||
})
|
||||
}
|
||||
aiProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true}
|
||||
}
|
||||
params := database.InsertChatModelConfigParams{
|
||||
Provider: takeFirst(seed.Provider, "openai"),
|
||||
Provider: providerName,
|
||||
Model: takeFirst(seed.Model, "gpt-4o-mini"),
|
||||
DisplayName: takeFirst(seed.DisplayName, "Test Model"),
|
||||
CreatedBy: seed.CreatedBy,
|
||||
@@ -160,7 +181,7 @@ func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelCon
|
||||
ContextLimit: takeFirst(seed.ContextLimit, defaultChatModelContextLimit),
|
||||
CompressionThreshold: takeFirst(seed.CompressionThreshold, defaultChatModelCompressionThreshold),
|
||||
Options: takeFirstSlice(seed.Options, json.RawMessage(`{}`)),
|
||||
AIProviderID: seed.AIProviderID,
|
||||
AIProviderID: aiProviderID,
|
||||
}
|
||||
for _, fn := range munge {
|
||||
fn(¶ms)
|
||||
@@ -263,9 +284,36 @@ func ChatProvider(t testing.TB, db database.Store, seed database.ChatProvider, m
|
||||
for _, fn := range munge {
|
||||
fn(¶ms)
|
||||
}
|
||||
provider, err := db.InsertChatProvider(genCtx, params)
|
||||
require.NoError(t, err, "insert chat provider")
|
||||
return provider
|
||||
provider := AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AIProviderType(params.Provider),
|
||||
Name: "test-" + uuid.NewString(),
|
||||
DisplayName: sql.NullString{String: params.DisplayName, Valid: params.DisplayName != ""},
|
||||
BaseUrl: params.BaseUrl,
|
||||
}, func(p *database.InsertAIProviderParams) {
|
||||
p.Enabled = params.Enabled
|
||||
})
|
||||
if params.APIKey != "" {
|
||||
AIProviderKey(t, db, database.AIProviderKey{
|
||||
ProviderID: provider.ID,
|
||||
APIKey: params.APIKey,
|
||||
ApiKeyKeyID: params.ApiKeyKeyID,
|
||||
})
|
||||
}
|
||||
return database.ChatProvider{
|
||||
ID: provider.ID,
|
||||
Provider: params.Provider,
|
||||
DisplayName: params.DisplayName,
|
||||
APIKey: params.APIKey,
|
||||
BaseUrl: params.BaseUrl,
|
||||
ApiKeyKeyID: params.ApiKeyKeyID,
|
||||
CreatedBy: params.CreatedBy,
|
||||
Enabled: params.Enabled,
|
||||
CentralApiKeyEnabled: params.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: params.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: params.AllowCentralApiKeyFallback,
|
||||
CreatedAt: provider.CreatedAt,
|
||||
UpdatedAt: provider.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerConfig) database.MCPServerConfig {
|
||||
|
||||
@@ -465,6 +465,14 @@ func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uui
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatModelConfigsByAIProviderID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigsByAIProviderID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigsByProvider(ctx, provider)
|
||||
@@ -473,14 +481,6 @@ func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context,
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatQueuedMessage(ctx, arg)
|
||||
@@ -809,14 +809,6 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -1089,6 +1081,14 @@ func (m queryMetricsStore) GetAIProviderKeysByProviderID(ctx context.Context, pr
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIProviderKeysByProviderIDs(ctx, providerIds)
|
||||
m.queryLatencies.WithLabelValues("GetAIProviderKeysByProviderIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderKeysByProviderIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIProviders(ctx, arg)
|
||||
@@ -1537,46 +1537,6 @@ func (m queryMetricsStore) GetChatPlanModeInstructions(ctx context.Context) (str
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByIDForUpdate(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByIDForUpdate").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByIDForUpdate").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByProvider(ctx, provider)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByProviderForUpdate(ctx, provider)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviderByProviderForUpdate").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProviderForUpdate").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID)
|
||||
@@ -1833,14 +1793,6 @@ func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledChatProviders(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx)
|
||||
@@ -3033,14 +2985,6 @@ func (m queryMetricsStore) GetUserChatPersonalModelOverride(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
@@ -3825,14 +3769,6 @@ func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg)
|
||||
@@ -4881,14 +4817,6 @@ func (m queryMetricsStore) UpdateChatPlanModeByID(ctx context.Context, arg datab
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatProvider(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatStatus(ctx, arg)
|
||||
@@ -5321,14 +5249,6 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserCodeDiffDisplayMode(ctx, arg)
|
||||
@@ -6105,14 +6025,6 @@ func (m queryMetricsStore) UpsertUserChatPersonalModelOverride(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
|
||||
|
||||
@@ -760,6 +760,20 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigByID(ctx, id any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigsByAIProviderID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatModelConfigsByAIProviderID", ctx, aiProviderID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigsByAIProviderID indicates an expected call of DeleteChatModelConfigsByAIProviderID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByAIProviderID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByAIProviderID), ctx, aiProviderID)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigsByProvider mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -774,20 +788,6 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByProvider(ctx, provider
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByProvider", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByProvider), ctx, provider)
|
||||
}
|
||||
|
||||
// DeleteChatProviderByID mocks base method.
|
||||
func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteChatQueuedMessage mocks base method.
|
||||
func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1376,20 +1376,6 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1889,6 +1875,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderID(ctx, providerID a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderID", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderID), ctx, providerID)
|
||||
}
|
||||
|
||||
// GetAIProviderKeysByProviderIDs mocks base method.
|
||||
func (m *MockStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAIProviderKeysByProviderIDs", ctx, providerIds)
|
||||
ret0, _ := ret[0].([]database.AIProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAIProviderKeysByProviderIDs indicates an expected call of GetAIProviderKeysByProviderIDs.
|
||||
func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderIDs(ctx, providerIds any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderIDs", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderIDs), ctx, providerIds)
|
||||
}
|
||||
|
||||
// GetAIProviders mocks base method.
|
||||
func (m *MockStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2849,81 +2850,6 @@ func (mr *MockStoreMockRecorder) GetChatPlanModeInstructions(ctx any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatPlanModeInstructions", reflect.TypeOf((*MockStore)(nil).GetChatPlanModeInstructions), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByID indicates an expected call of GetChatProviderByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatProviderByIDForUpdate mocks base method.
|
||||
func (m *MockStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByIDForUpdate", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByIDForUpdate indicates an expected call of GetChatProviderByIDForUpdate.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByIDForUpdate(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByIDForUpdate), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatProviderByProvider mocks base method.
|
||||
func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider)
|
||||
}
|
||||
|
||||
// GetChatProviderByProviderForUpdate mocks base method.
|
||||
func (m *MockStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviderByProviderForUpdate", ctx, provider)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviderByProviderForUpdate indicates an expected call of GetChatProviderByProviderForUpdate.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviderByProviderForUpdate(ctx, provider any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProviderForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProviderForUpdate), ctx, provider)
|
||||
}
|
||||
|
||||
// GetChatProviders mocks base method.
|
||||
func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatProviders", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatProviders indicates an expected call of GetChatProviders.
|
||||
func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetChatQueuedMessages mocks base method.
|
||||
func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3404,21 +3330,6 @@ func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetEnabledChatProviders mocks base method.
|
||||
func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx)
|
||||
ret0, _ := ret[0].([]database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetEnabledMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5684,21 +5595,6 @@ func (mr *MockStoreMockRecorder) GetUserChatPersonalModelOverride(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).GetUserChatPersonalModelOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys mocks base method.
|
||||
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7183,21 +7079,6 @@ func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatProvider mocks base method.
|
||||
func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatProvider indicates an expected call of InsertChatProvider.
|
||||
func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatQueuedMessage mocks base method.
|
||||
func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9234,21 +9115,6 @@ func (mr *MockStoreMockRecorder) UpdateChatPlanModeByID(ctx, arg any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPlanModeByID", reflect.TypeOf((*MockStore)(nil).UpdateChatPlanModeByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatProvider mocks base method.
|
||||
func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatProvider indicates an expected call of UpdateChatProvider.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatStatus mocks base method.
|
||||
func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -10035,21 +9901,6 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserCodeDiffDisplayMode mocks base method.
|
||||
func (m *MockStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -11446,21 +11297,6 @@ func (mr *MockStoreMockRecorder) UpsertUserChatPersonalModelOverride(ctx, arg an
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertUserChatPersonalModelOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertWebpushVAPIDKeys mocks base method.
|
||||
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+1
-60
@@ -1531,30 +1531,11 @@ CREATE TABLE chat_model_configs (
|
||||
compression_threshold integer NOT NULL,
|
||||
options jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
ai_provider_id uuid,
|
||||
CONSTRAINT chat_model_configs_ai_provider_required_when_active CHECK (((deleted = true) OR (ai_provider_id IS NOT NULL))),
|
||||
CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))),
|
||||
CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0))
|
||||
);
|
||||
|
||||
CREATE TABLE chat_providers (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
provider text NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
api_key text DEFAULT ''::text NOT NULL,
|
||||
api_key_key_id text,
|
||||
created_by uuid,
|
||||
enabled boolean DEFAULT true NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
base_url text DEFAULT ''::text NOT NULL,
|
||||
central_api_key_enabled boolean DEFAULT true NOT NULL,
|
||||
allow_user_api_key boolean DEFAULT false NOT NULL,
|
||||
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
|
||||
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
|
||||
CREATE TABLE chat_queued_messages (
|
||||
id bigint NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
@@ -3040,17 +3021,6 @@ COMMENT ON COLUMN user_ai_provider_keys.api_key IS 'User-owned API key used to a
|
||||
|
||||
COMMENT ON COLUMN user_ai_provider_keys.api_key_key_id IS 'The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted.';
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
chat_provider_id uuid NOT NULL,
|
||||
api_key text NOT NULL,
|
||||
api_key_key_id text,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
|
||||
);
|
||||
|
||||
CREATE TABLE user_configs (
|
||||
user_id uuid NOT NULL,
|
||||
key character varying(256) NOT NULL,
|
||||
@@ -3667,12 +3637,6 @@ ALTER TABLE ONLY chat_messages
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3889,12 +3853,6 @@ ALTER TABLE ONLY user_ai_provider_keys
|
||||
ALTER TABLE ONLY user_ai_provider_keys
|
||||
ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
|
||||
@@ -4112,8 +4070,6 @@ CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING b
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
|
||||
|
||||
CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
@@ -4444,12 +4400,6 @@ ALTER TABLE ONLY chat_model_configs
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY chat_providers
|
||||
ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4687,15 +4637,6 @@ ALTER TABLE ONLY user_ai_provider_keys
|
||||
ALTER TABLE ONLY user_ai_provider_keys
|
||||
ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@ const (
|
||||
ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id);
|
||||
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
|
||||
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
|
||||
@@ -105,9 +103,6 @@ const (
|
||||
ForeignKeyUserAiProviderKeysAiProviderID ForeignKeyConstraint = "user_ai_provider_keys_ai_provider_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserAiProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_ai_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserAiProviderKeysUserID ForeignKeyConstraint = "user_ai_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ChatProvider is the fixture shape accepted by dbgen.ChatProvider.
|
||||
//
|
||||
//nolint:revive
|
||||
type ChatProvider struct {
|
||||
ID uuid.UUID
|
||||
Provider string
|
||||
DisplayName string
|
||||
APIKey string
|
||||
BaseUrl string
|
||||
ApiKeyKeyID sql.NullString
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
CreatedBy uuid.NullUUID
|
||||
Enabled bool
|
||||
CentralApiKeyEnabled bool
|
||||
AllowUserApiKey bool
|
||||
AllowCentralApiKeyFallback bool
|
||||
}
|
||||
|
||||
// InsertChatProviderParams is the callback parameter shape accepted by
|
||||
// dbgen.ChatProvider.
|
||||
//
|
||||
//nolint:revive
|
||||
type InsertChatProviderParams struct {
|
||||
Provider string
|
||||
DisplayName string
|
||||
APIKey string
|
||||
BaseUrl string
|
||||
ApiKeyKeyID sql.NullString
|
||||
CreatedBy uuid.NullUUID
|
||||
Enabled bool
|
||||
CentralApiKeyEnabled bool
|
||||
AllowUserApiKey bool
|
||||
AllowCentralApiKeyFallback bool
|
||||
}
|
||||
@@ -1,8 +1,15 @@
|
||||
DROP TABLE IF EXISTS user_chat_provider_keys;
|
||||
|
||||
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
|
||||
DO $$
|
||||
BEGIN
|
||||
IF to_regclass('chat_providers') IS NULL THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
DROP COLUMN IF EXISTS central_api_key_enabled,
|
||||
DROP COLUMN IF EXISTS allow_user_api_key,
|
||||
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
|
||||
END $$;
|
||||
|
||||
@@ -1,27 +1,34 @@
|
||||
-- Restore placeholder provider rows before re-adding the provider FK.
|
||||
--
|
||||
-- The companion up migration dropped chat_model_configs.provider's foreign
|
||||
-- key, so historical model-config rows can outlive a deleted provider row.
|
||||
-- These backfilled providers are deliberately disabled stubs with empty
|
||||
-- credential fields, which lets rollback restore referential integrity
|
||||
-- without re-enabling a provider. This insert depends on the current
|
||||
-- provider whitelist still admitting every historical
|
||||
-- chat_model_configs.provider value, and on the omitted columns keeping
|
||||
-- compatible defaults. Operators restoring a real provider should update the
|
||||
-- stub row, including credential-policy flags such as
|
||||
-- central_api_key_enabled, before enabling it, rather than insert a second
|
||||
-- row with the same provider name.
|
||||
INSERT INTO chat_providers (provider, enabled)
|
||||
SELECT DISTINCT
|
||||
DO $$
|
||||
BEGIN
|
||||
IF to_regclass('chat_providers') IS NULL THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- Restore placeholder provider rows before re-adding the provider FK.
|
||||
--
|
||||
-- The companion up migration dropped chat_model_configs.provider's foreign
|
||||
-- key, so historical model-config rows can outlive a deleted provider row.
|
||||
-- These backfilled providers are deliberately disabled stubs with empty
|
||||
-- credential fields, which lets rollback restore referential integrity
|
||||
-- without re-enabling a provider. This insert depends on the current
|
||||
-- provider whitelist still admitting every historical
|
||||
-- chat_model_configs.provider value, and on the omitted columns keeping
|
||||
-- compatible defaults. Operators restoring a real provider should update the
|
||||
-- stub row, including credential-policy flags such as
|
||||
-- central_api_key_enabled, before enabling it, rather than insert a second
|
||||
-- row with the same provider name.
|
||||
INSERT INTO chat_providers (provider, enabled)
|
||||
SELECT DISTINCT
|
||||
cmc.provider,
|
||||
FALSE
|
||||
FROM
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
LEFT JOIN
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
WHERE
|
||||
WHERE
|
||||
cp.provider IS NULL;
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
ALTER TABLE chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_provider_fkey
|
||||
FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE;
|
||||
END $$;
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
WITH migrated_provider_ids AS (
|
||||
SELECT id
|
||||
FROM chat_providers
|
||||
UNION
|
||||
SELECT id
|
||||
FROM ai_providers
|
||||
WHERE name LIKE 'agents-%'
|
||||
AND deleted = TRUE
|
||||
)
|
||||
UPDATE chat_model_configs
|
||||
SET ai_provider_id = NULL
|
||||
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
|
||||
DO $$
|
||||
BEGIN
|
||||
IF to_regclass('chat_providers') IS NULL THEN
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
WITH migrated_provider_ids AS (
|
||||
WITH migrated_provider_ids AS (
|
||||
SELECT id
|
||||
FROM chat_providers
|
||||
UNION
|
||||
@@ -19,11 +12,12 @@ WITH migrated_provider_ids AS (
|
||||
FROM ai_providers
|
||||
WHERE name LIKE 'agents-%'
|
||||
AND deleted = TRUE
|
||||
)
|
||||
DELETE FROM user_ai_provider_keys
|
||||
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
|
||||
)
|
||||
UPDATE chat_model_configs
|
||||
SET ai_provider_id = NULL
|
||||
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
|
||||
|
||||
WITH migrated_provider_ids AS (
|
||||
WITH migrated_provider_ids AS (
|
||||
SELECT id
|
||||
FROM chat_providers
|
||||
UNION
|
||||
@@ -31,11 +25,11 @@ WITH migrated_provider_ids AS (
|
||||
FROM ai_providers
|
||||
WHERE name LIKE 'agents-%'
|
||||
AND deleted = TRUE
|
||||
)
|
||||
DELETE FROM ai_provider_keys
|
||||
WHERE provider_id IN (SELECT id FROM migrated_provider_ids);
|
||||
)
|
||||
DELETE FROM user_ai_provider_keys
|
||||
WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids);
|
||||
|
||||
WITH migrated_provider_ids AS (
|
||||
WITH migrated_provider_ids AS (
|
||||
SELECT id
|
||||
FROM chat_providers
|
||||
UNION
|
||||
@@ -43,6 +37,19 @@ WITH migrated_provider_ids AS (
|
||||
FROM ai_providers
|
||||
WHERE name LIKE 'agents-%'
|
||||
AND deleted = TRUE
|
||||
)
|
||||
DELETE FROM ai_providers
|
||||
WHERE id IN (SELECT id FROM migrated_provider_ids);
|
||||
)
|
||||
DELETE FROM ai_provider_keys
|
||||
WHERE provider_id IN (SELECT id FROM migrated_provider_ids);
|
||||
|
||||
WITH migrated_provider_ids AS (
|
||||
SELECT id
|
||||
FROM chat_providers
|
||||
UNION
|
||||
SELECT id
|
||||
FROM ai_providers
|
||||
WHERE name LIKE 'agents-%'
|
||||
AND deleted = TRUE
|
||||
)
|
||||
DELETE FROM ai_providers
|
||||
WHERE id IN (SELECT id FROM migrated_provider_ids);
|
||||
END $$;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
-- no-op. Legacy chat provider tables are intentionally not recreated from AI
|
||||
-- provider definitions. Rolling back past this migration is not reversible at
|
||||
-- the schema level.
|
||||
@@ -0,0 +1,140 @@
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM chat_providers cp
|
||||
JOIN ai_providers ap ON ap.name = 'agents-' || cp.provider
|
||||
WHERE ap.deleted = FALSE
|
||||
AND ap.id != cp.id
|
||||
) THEN
|
||||
RAISE EXCEPTION 'cannot finalize chat provider migration because a live agents-* AI provider name already exists';
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
INSERT INTO ai_providers (
|
||||
id,
|
||||
type,
|
||||
name,
|
||||
display_name,
|
||||
enabled,
|
||||
base_url,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
cp.id,
|
||||
cp.provider::ai_provider_type,
|
||||
'agents-' || cp.provider,
|
||||
NULLIF(cp.display_name, ''),
|
||||
cp.enabled,
|
||||
cp.base_url,
|
||||
cp.created_at,
|
||||
cp.updated_at
|
||||
FROM chat_providers cp
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM ai_providers ap
|
||||
WHERE ap.id = cp.id
|
||||
);
|
||||
|
||||
UPDATE ai_providers ap
|
||||
SET
|
||||
type = cp.provider::ai_provider_type,
|
||||
name = 'agents-' || cp.provider,
|
||||
display_name = NULLIF(cp.display_name, ''),
|
||||
enabled = cp.enabled,
|
||||
deleted = FALSE,
|
||||
base_url = cp.base_url,
|
||||
updated_at = GREATEST(cp.updated_at, ap.updated_at)
|
||||
FROM chat_providers cp
|
||||
WHERE ap.id = cp.id
|
||||
AND (cp.updated_at > ap.updated_at OR ap.deleted);
|
||||
|
||||
DELETE FROM ai_provider_keys apk
|
||||
USING chat_providers cp
|
||||
WHERE cp.id = apk.provider_id
|
||||
AND cp.api_key = ''
|
||||
AND cp.updated_at > apk.updated_at;
|
||||
|
||||
WITH runtime_provider_keys AS (
|
||||
SELECT DISTINCT ON (apk.provider_id)
|
||||
apk.id,
|
||||
apk.provider_id
|
||||
FROM ai_provider_keys apk
|
||||
JOIN chat_providers cp ON cp.id = apk.provider_id
|
||||
WHERE cp.api_key != ''
|
||||
ORDER BY
|
||||
apk.provider_id ASC,
|
||||
apk.created_at ASC,
|
||||
apk.id ASC
|
||||
)
|
||||
UPDATE ai_provider_keys apk
|
||||
SET
|
||||
api_key = cp.api_key,
|
||||
api_key_key_id = cp.api_key_key_id,
|
||||
updated_at = cp.updated_at
|
||||
FROM runtime_provider_keys rpk
|
||||
JOIN chat_providers cp ON cp.id = rpk.provider_id
|
||||
WHERE apk.id = rpk.id
|
||||
AND cp.updated_at > apk.updated_at;
|
||||
|
||||
INSERT INTO ai_provider_keys (
|
||||
id,
|
||||
provider_id,
|
||||
api_key,
|
||||
api_key_key_id,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
cp.id,
|
||||
cp.api_key,
|
||||
cp.api_key_key_id,
|
||||
cp.updated_at,
|
||||
cp.updated_at
|
||||
FROM chat_providers cp
|
||||
WHERE cp.api_key != ''
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM ai_provider_keys apk
|
||||
WHERE apk.provider_id = cp.id
|
||||
);
|
||||
|
||||
INSERT INTO user_ai_provider_keys (
|
||||
id,
|
||||
user_id,
|
||||
ai_provider_id,
|
||||
api_key,
|
||||
api_key_key_id,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
ucpk.id,
|
||||
ucpk.user_id,
|
||||
ucpk.chat_provider_id,
|
||||
ucpk.api_key,
|
||||
ucpk.api_key_key_id,
|
||||
ucpk.created_at,
|
||||
ucpk.updated_at
|
||||
FROM user_chat_provider_keys ucpk
|
||||
ON CONFLICT (user_id, ai_provider_id) DO UPDATE
|
||||
SET
|
||||
api_key = EXCLUDED.api_key,
|
||||
api_key_key_id = EXCLUDED.api_key_key_id,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
WHERE user_ai_provider_keys.updated_at < EXCLUDED.updated_at;
|
||||
|
||||
UPDATE chat_model_configs cmc
|
||||
SET ai_provider_id = cp.id
|
||||
FROM chat_providers cp
|
||||
WHERE cmc.provider = cp.provider
|
||||
AND cmc.ai_provider_id IS NULL;
|
||||
|
||||
ALTER TABLE chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_ai_provider_required_when_active
|
||||
CHECK (deleted = TRUE OR ai_provider_id IS NOT NULL);
|
||||
|
||||
DROP TABLE IF EXISTS user_chat_provider_keys;
|
||||
DROP TABLE IF EXISTS chat_providers;
|
||||
@@ -4682,23 +4682,6 @@ type ChatModelConfig struct {
|
||||
AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"`
|
||||
}
|
||||
|
||||
type ChatProvider struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
type ChatQueuedMessage struct {
|
||||
ID int64 `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
@@ -5706,16 +5689,6 @@ type UserAiProviderKey struct {
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserChatProviderKey struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
|
||||
@@ -122,8 +122,8 @@ type sqlcQuerier interface {
|
||||
// archive-cleanup retry).
|
||||
DeleteChatDebugDataByChatID(ctx context.Context, arg DeleteChatDebugDataByChatIDParams) (int64, error)
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error
|
||||
DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error
|
||||
DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error
|
||||
@@ -196,7 +196,6 @@ type sqlcQuerier interface {
|
||||
DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error
|
||||
DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
DeleteUserSkillByUserIDAndName(ctx context.Context, arg DeleteUserSkillByUserIDAndNameParams) (UserSkill, error)
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
@@ -272,6 +271,9 @@ type sqlcQuerier interface {
|
||||
// key per provider; multiple keys are stored to support future
|
||||
// failover and rotation flows.
|
||||
GetAIProviderKeysByProviderID(ctx context.Context, providerID uuid.UUID) ([]AIProviderKey, error)
|
||||
// Returns all keys for the requested providers, ordered by provider then created_at ASC
|
||||
// so callers can select the oldest non-empty key per provider without issuing N queries.
|
||||
GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error)
|
||||
// Returns AI provider rows. Soft-deleted and disabled rows are excluded
|
||||
// unless include_deleted or include_disabled is set.
|
||||
GetAIProviders(ctx context.Context, arg GetAIProvidersParams) ([]AIProvider, error)
|
||||
@@ -384,11 +386,6 @@ type sqlcQuerier interface {
|
||||
// personal chat model overrides. It defaults to false when unset.
|
||||
GetChatPersonalModelOverridesEnabled(ctx context.Context) (bool, error)
|
||||
GetChatPlanModeInstructions(ctx context.Context) (string, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
// Returns the chat retention period in days. Chats archived longer
|
||||
// than this and orphaned chat files older than this are purged by
|
||||
@@ -452,7 +449,6 @@ type sqlcQuerier interface {
|
||||
// Check both to ensure the selected config is actually usable.
|
||||
GetEnabledChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
|
||||
GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error)
|
||||
@@ -760,7 +756,6 @@ type sqlcQuerier interface {
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error)
|
||||
GetUserChatPersonalModelOverride(ctx context.Context, arg GetUserChatPersonalModelOverrideParams) (string, error)
|
||||
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
|
||||
// Returns the total spend for a user in the given period.
|
||||
// When organization_id is NULL, spend across all organizations is
|
||||
// returned (global behavior). Otherwise only spend within the
|
||||
@@ -932,7 +927,6 @@ type sqlcQuerier interface {
|
||||
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
|
||||
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
|
||||
InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error)
|
||||
InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error)
|
||||
InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error)
|
||||
@@ -1214,7 +1208,6 @@ type sqlcQuerier interface {
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error
|
||||
UpdateChatPlanModeByID(ctx context.Context, arg UpdateChatPlanModeByIDParams) (Chat, error)
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
|
||||
UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error)
|
||||
UpdateChatTitleByID(ctx context.Context, arg UpdateChatTitleByIDParams) (Chat, error)
|
||||
@@ -1283,7 +1276,6 @@ type sqlcQuerier interface {
|
||||
UpdateUserAgentChatSendShortcut(ctx context.Context, arg UpdateUserAgentChatSendShortcutParams) (string, error)
|
||||
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
|
||||
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
|
||||
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpdateUserCodeDiffDisplayMode(ctx context.Context, arg UpdateUserCodeDiffDisplayModeParams) (string, error)
|
||||
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
|
||||
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
|
||||
@@ -1408,7 +1400,6 @@ type sqlcQuerier interface {
|
||||
UpsertUserAIProviderKey(ctx context.Context, arg UpsertUserAIProviderKeyParams) (UserAiProviderKey, error)
|
||||
UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error
|
||||
UpsertUserChatPersonalModelOverride(ctx context.Context, arg UpsertUserChatPersonalModelOverrideParams) error
|
||||
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
|
||||
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
|
||||
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
|
||||
|
||||
@@ -10609,20 +10609,12 @@ func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) {
|
||||
}, func(params *database.InsertChatModelConfigParams) {
|
||||
params.Enabled = false
|
||||
})
|
||||
legacyProvider := dbgen.ChatProvider(t, store, database.ChatProvider{Provider: "google"})
|
||||
legacyConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
|
||||
Provider: legacyProvider.Provider,
|
||||
Model: "google-model-" + uuid.NewString(),
|
||||
})
|
||||
|
||||
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == enabledConfig.ID
|
||||
}))
|
||||
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == legacyConfig.ID
|
||||
}))
|
||||
require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == disabledProviderConfig.ID
|
||||
}))
|
||||
@@ -10636,6 +10628,46 @@ func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) {
|
||||
|
||||
_, err = store.GetEnabledChatModelConfigByID(ctx, disabledProviderConfig.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
_, err = store.GetEnabledChatModelConfigByID(ctx, disabledModelConfig.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
}
|
||||
|
||||
func insertChatModelConfigForTest(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
store database.Store,
|
||||
params database.InsertChatModelConfigParams,
|
||||
) (database.ChatModelConfig, error) {
|
||||
t.Helper()
|
||||
if params.AIProviderID.Valid {
|
||||
return store.InsertChatModelConfig(ctx, params)
|
||||
}
|
||||
providerName := params.Provider
|
||||
if providerName == "" {
|
||||
providerName = "openai"
|
||||
params.Provider = providerName
|
||||
}
|
||||
providers, err := store.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, err
|
||||
}
|
||||
var provider database.AIProvider
|
||||
for _, candidate := range providers {
|
||||
if candidate.Type != database.AIProviderType(providerName) {
|
||||
continue
|
||||
}
|
||||
if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) {
|
||||
provider = candidate
|
||||
}
|
||||
}
|
||||
if provider.ID == uuid.Nil {
|
||||
provider = dbgen.AIProvider(t, store, database.AIProvider{
|
||||
Type: database.AIProviderType(providerName),
|
||||
})
|
||||
}
|
||||
params.AIProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true}
|
||||
return store.InsertChatModelConfig(ctx, params)
|
||||
}
|
||||
|
||||
func TestInsertChatMessages(t *testing.T) {
|
||||
@@ -10653,7 +10685,7 @@ func TestInsertChatMessages(t *testing.T) {
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
modelConfig, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelConfig, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
DisplayName: displayName,
|
||||
@@ -10681,14 +10713,13 @@ func TestInsertChatMessages(t *testing.T) {
|
||||
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
provider := "openai"
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: provider,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelConfigA := insertModelConfig(
|
||||
t,
|
||||
@@ -10850,18 +10881,21 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
|
||||
// A chat_providers row is required as a FK for model configs.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
// An AI provider row is required as a FK for model configs.
|
||||
provider := dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "test-" + uuid.NewString(),
|
||||
DisplayName: sql.NullString{String: "OpenAI", Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
dbgen.AIProviderKey(t, db, database.AIProviderKey{
|
||||
ProviderID: provider.ID,
|
||||
APIKey: "test-key",
|
||||
})
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
@@ -11227,16 +11261,15 @@ func TestGetPRInsights(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mc, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
mc, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4",
|
||||
DisplayName: "Claude 4",
|
||||
@@ -11683,7 +11716,7 @@ func TestGetPRInsights(t *testing.T) {
|
||||
store, userID, _, orgID := setupChatInfra(t)
|
||||
|
||||
const modelName = "claude-4.1"
|
||||
emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
|
||||
emptyDisplayModel, err := insertChatModelConfigForTest(context.Background(), t, store, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: modelName,
|
||||
DisplayName: "",
|
||||
@@ -11791,16 +11824,15 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
// Use background context for fixture setup so the
|
||||
// timed test context doesn't tick during DB init.
|
||||
bg := context.Background()
|
||||
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
@@ -11972,16 +12004,15 @@ func TestChatPinOrderConstraints(t *testing.T) {
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
||||
|
||||
bg := context.Background()
|
||||
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
@@ -12065,16 +12096,15 @@ func TestChatLabels(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
||||
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
@@ -12365,16 +12395,15 @@ func TestUpdateChatLastTurnSummary(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
||||
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
@@ -12466,16 +12495,15 @@ func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) {
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -12659,16 +12687,15 @@ func TestDeleteChatDebugDataAfterMessageIDStepLevelFieldBoundariesAndNulls(t *te
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-step-boundaries-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -12917,16 +12944,15 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) {
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-finalize-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -13356,16 +13382,15 @@ func TestChatDebugSQLGuards(t *testing.T) {
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-guards-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -13490,16 +13515,15 @@ func TestChatDebugRunCOALESCEPreservation(t *testing.T) {
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-coalesce-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -13605,16 +13629,15 @@ func TestChatDebugStepCOALESCEPreservation(t *testing.T) {
|
||||
providerName := "openai"
|
||||
modelName := "debug-step-coalesce-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -13730,16 +13753,15 @@ func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) {
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-null-msg-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -13828,16 +13850,15 @@ func TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns(t *testi
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-started-before-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -13940,16 +13961,15 @@ func TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns(t *testing.T)
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-by-chat-started-before-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
@@ -14029,17 +14049,13 @@ func TestGetChatsFilter(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
provider := dbgen.AIProviderWithOptionalKey(t, store, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
}, "test-key")
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
Model: "test-model-" + uuid.NewString(),
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
@@ -14265,16 +14281,15 @@ func TestChatHasUnread(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
dbgen.ChatProvider(t, store, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model-" + uuid.NewString(),
|
||||
DisplayName: "Test Model",
|
||||
|
||||
+68
-497
@@ -276,6 +276,51 @@ func (q *sqlQuerier) GetAIProviderKeysByProviderID(ctx context.Context, provider
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getAIProviderKeysByProviderIDs = `-- name: GetAIProviderKeysByProviderIDs :many
|
||||
SELECT
|
||||
id, provider_id, api_key, api_key_key_id, created_at, updated_at
|
||||
FROM
|
||||
ai_provider_keys
|
||||
WHERE
|
||||
provider_id = ANY($1::uuid[])
|
||||
ORDER BY
|
||||
provider_id ASC,
|
||||
created_at ASC,
|
||||
id ASC
|
||||
`
|
||||
|
||||
// Returns all keys for the requested providers, ordered by provider then created_at ASC
|
||||
// so callers can select the oldest non-empty key per provider without issuing N queries.
|
||||
func (q *sqlQuerier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAIProviderKeysByProviderIDs, pq.Array(providerIds))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []AIProviderKey
|
||||
for rows.Next() {
|
||||
var i AIProviderKey
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertAIProviderKey = `-- name: InsertAIProviderKey :one
|
||||
INSERT INTO ai_provider_keys (
|
||||
id,
|
||||
@@ -5020,6 +5065,23 @@ func (q *sqlQuerier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteChatModelConfigsByAIProviderID = `-- name: DeleteChatModelConfigsByAIProviderID :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
SET
|
||||
deleted = TRUE,
|
||||
deleted_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
ai_provider_id = $1::uuid
|
||||
AND deleted = FALSE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteChatModelConfigsByAIProviderID, aiProviderID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteChatModelConfigsByProvider = `-- name: DeleteChatModelConfigsByProvider :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
@@ -5164,18 +5226,14 @@ SELECT
|
||||
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
LEFT JOIN
|
||||
JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.id = $1::uuid
|
||||
AND cmc.deleted = FALSE
|
||||
AND cmc.enabled = TRUE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
)
|
||||
AND ap.enabled = TRUE
|
||||
AND ap.deleted = FALSE
|
||||
`
|
||||
|
||||
// Providers can be disabled independently of their model configs.
|
||||
@@ -5209,17 +5267,13 @@ SELECT
|
||||
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
LEFT JOIN
|
||||
JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.enabled = TRUE
|
||||
AND cmc.deleted = FALSE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
)
|
||||
AND ap.enabled = TRUE
|
||||
AND ap.deleted = FALSE
|
||||
ORDER BY
|
||||
cmc.provider ASC,
|
||||
cmc.model ASC,
|
||||
@@ -5435,369 +5489,6 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteChatProviderByID = `-- name: DeleteChatProviderByID :exec
|
||||
DELETE FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteChatProviderByID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getChatProviderByID = `-- name: GetChatProviderByID :one
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatProviderByID, id)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatProviderByIDForUpdate = `-- name: GetChatProviderByIDForUpdate :one
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = $1::uuid
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatProviderByIDForUpdate, id)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
provider = $1::text
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatProviderByProvider, provider)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatProviderByProviderForUpdate = `-- name: GetChatProviderByProviderForUpdate :one
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
provider = $1::text
|
||||
FOR UPDATE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatProviderByProviderForUpdate, provider)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatProviders = `-- name: GetChatProviders :many
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
ORDER BY
|
||||
provider ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatProviders)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatProvider
|
||||
for rows.Next() {
|
||||
var i ChatProvider
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
enabled = TRUE
|
||||
ORDER BY
|
||||
provider ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getEnabledChatProviders)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatProvider
|
||||
for rows.Next() {
|
||||
var i ChatProvider
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertChatProvider = `-- name: InsertChatProvider :one
|
||||
INSERT INTO chat_providers (
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled,
|
||||
central_api_key_enabled,
|
||||
allow_user_api_key,
|
||||
allow_central_api_key_fallback
|
||||
) VALUES (
|
||||
$1::text,
|
||||
$2::text,
|
||||
$3::text,
|
||||
$4::text,
|
||||
$5::text,
|
||||
$6::uuid,
|
||||
$7::boolean,
|
||||
$8::boolean,
|
||||
$9::boolean,
|
||||
$10::boolean
|
||||
)
|
||||
RETURNING
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
`
|
||||
|
||||
type InsertChatProviderParams struct {
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertChatProvider,
|
||||
arg.Provider,
|
||||
arg.DisplayName,
|
||||
arg.APIKey,
|
||||
arg.BaseUrl,
|
||||
arg.ApiKeyKeyID,
|
||||
arg.CreatedBy,
|
||||
arg.Enabled,
|
||||
arg.CentralApiKeyEnabled,
|
||||
arg.AllowUserApiKey,
|
||||
arg.AllowCentralApiKeyFallback,
|
||||
)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatProvider = `-- name: UpdateChatProvider :one
|
||||
UPDATE
|
||||
chat_providers
|
||||
SET
|
||||
display_name = $1::text,
|
||||
api_key = $2::text,
|
||||
base_url = $3::text,
|
||||
api_key_key_id = $4::text,
|
||||
enabled = $5::boolean,
|
||||
central_api_key_enabled = $6::boolean,
|
||||
allow_user_api_key = $7::boolean,
|
||||
allow_central_api_key_fallback = $8::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $9::uuid
|
||||
RETURNING
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
`
|
||||
|
||||
type UpdateChatProviderParams struct {
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatProvider,
|
||||
arg.DisplayName,
|
||||
arg.APIKey,
|
||||
arg.BaseUrl,
|
||||
arg.ApiKeyKeyID,
|
||||
arg.Enabled,
|
||||
arg.CentralApiKeyEnabled,
|
||||
arg.AllowUserApiKey,
|
||||
arg.AllowCentralApiKeyFallback,
|
||||
arg.ID,
|
||||
)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Provider,
|
||||
&i.DisplayName,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedBy,
|
||||
&i.Enabled,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const acquireChats = `-- name: AcquireChats :many
|
||||
WITH acquired_chats AS (
|
||||
UPDATE
|
||||
@@ -27555,126 +27246,6 @@ func (q *sqlQuerier) UpdateUserSkillByUserIDAndName(ctx context.Context, arg Upd
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec
|
||||
DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2
|
||||
`
|
||||
|
||||
type DeleteUserChatProviderKeyParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many
|
||||
SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserChatProviderKey
|
||||
for rows.Next() {
|
||||
var i UserChatProviderKey
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChatProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one
|
||||
UPDATE user_chat_provider_keys
|
||||
SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW()
|
||||
WHERE user_id = $3 AND chat_provider_id = $4
|
||||
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateUserChatProviderKeyParams struct {
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserChatProviderKey,
|
||||
arg.APIKey,
|
||||
arg.ApiKeyKeyID,
|
||||
arg.UserID,
|
||||
arg.ChatProviderID,
|
||||
)
|
||||
var i UserChatProviderKey
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChatProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one
|
||||
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
|
||||
VALUES ($1, $2, $3, $4::text)
|
||||
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
|
||||
api_key = $3,
|
||||
api_key_key_id = $4::text,
|
||||
updated_at = NOW()
|
||||
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpsertUserChatProviderKeyParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) {
|
||||
row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey,
|
||||
arg.UserID,
|
||||
arg.ChatProviderID,
|
||||
arg.APIKey,
|
||||
arg.ApiKeyKeyID,
|
||||
)
|
||||
var i UserChatProviderKey
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChatProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const allUserIDs = `-- name: AllUserIDs :many
|
||||
SELECT DISTINCT id FROM USERS
|
||||
WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END
|
||||
|
||||
@@ -32,6 +32,20 @@ WHERE
|
||||
ORDER BY
|
||||
provider_id ASC;
|
||||
|
||||
-- name: GetAIProviderKeysByProviderIDs :many
|
||||
-- Returns all keys for the requested providers, ordered by provider then created_at ASC
|
||||
-- so callers can select the oldest non-empty key per provider without issuing N queries.
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
ai_provider_keys
|
||||
WHERE
|
||||
provider_id = ANY(@provider_ids::uuid[])
|
||||
ORDER BY
|
||||
provider_id ASC,
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: GetAIProviderKeys :many
|
||||
-- Returns AI provider key rows. By default, only rows whose parent
|
||||
-- provider is live (deleted = FALSE) are returned, so the API list
|
||||
|
||||
@@ -34,17 +34,13 @@ SELECT
|
||||
cmc.*
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
LEFT JOIN
|
||||
JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.enabled = TRUE
|
||||
AND cmc.deleted = FALSE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
)
|
||||
AND ap.enabled = TRUE
|
||||
AND ap.deleted = FALSE
|
||||
ORDER BY
|
||||
cmc.provider ASC,
|
||||
cmc.model ASC,
|
||||
@@ -58,18 +54,14 @@ FROM
|
||||
chat_model_configs cmc
|
||||
-- Providers can be disabled independently of their model configs.
|
||||
-- Check both to ensure the selected config is actually usable.
|
||||
LEFT JOIN
|
||||
JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.id = @id::uuid
|
||||
AND cmc.deleted = FALSE
|
||||
AND cmc.enabled = TRUE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
);
|
||||
AND ap.enabled = TRUE
|
||||
AND ap.deleted = FALSE;
|
||||
|
||||
-- name: InsertChatModelConfig :one
|
||||
INSERT INTO chat_model_configs (
|
||||
@@ -151,3 +143,14 @@ SET
|
||||
WHERE
|
||||
provider = @provider::text
|
||||
AND deleted = FALSE;
|
||||
|
||||
-- name: DeleteChatModelConfigsByAIProviderID :exec
|
||||
UPDATE
|
||||
chat_model_configs
|
||||
SET
|
||||
deleted = TRUE,
|
||||
deleted_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
ai_provider_id = @ai_provider_id::uuid
|
||||
AND deleted = FALSE;
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
-- name: GetChatProviderByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
|
||||
-- name: GetChatProviderByIDForUpdate :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
FOR UPDATE;
|
||||
|
||||
-- name: GetChatProviderByProvider :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
provider = @provider::text;
|
||||
|
||||
-- name: GetChatProviderByProviderForUpdate :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
provider = @provider::text
|
||||
FOR UPDATE;
|
||||
|
||||
-- name: GetChatProviders :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
ORDER BY
|
||||
provider ASC;
|
||||
|
||||
-- name: GetEnabledChatProviders :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
enabled = TRUE
|
||||
ORDER BY
|
||||
provider ASC;
|
||||
|
||||
-- name: InsertChatProvider :one
|
||||
INSERT INTO chat_providers (
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled,
|
||||
central_api_key_enabled,
|
||||
allow_user_api_key,
|
||||
allow_central_api_key_fallback
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@display_name::text,
|
||||
@api_key::text,
|
||||
@base_url::text,
|
||||
sqlc.narg('api_key_key_id')::text,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@enabled::boolean,
|
||||
@central_api_key_enabled::boolean,
|
||||
@allow_user_api_key::boolean,
|
||||
@allow_central_api_key_fallback::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatProvider :one
|
||||
UPDATE
|
||||
chat_providers
|
||||
SET
|
||||
display_name = @display_name::text,
|
||||
api_key = @api_key::text,
|
||||
base_url = @base_url::text,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
enabled = @enabled::boolean,
|
||||
central_api_key_enabled = @central_api_key_enabled::boolean,
|
||||
allow_user_api_key = @allow_user_api_key::boolean,
|
||||
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: DeleteChatProviderByID :exec
|
||||
DELETE FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
id = @id::uuid;
|
||||
@@ -1,20 +0,0 @@
|
||||
-- name: GetUserChatProviderKeys :many
|
||||
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
|
||||
|
||||
-- name: UpsertUserChatProviderKey :one
|
||||
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
|
||||
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
|
||||
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
|
||||
api_key = @api_key,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
updated_at = NOW()
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserChatProviderKey :one
|
||||
UPDATE user_chat_provider_keys
|
||||
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
|
||||
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserChatProviderKey :exec
|
||||
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
|
||||
@@ -25,8 +25,6 @@ const (
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id);
|
||||
UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider);
|
||||
UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
|
||||
UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
|
||||
@@ -99,8 +97,6 @@ const (
|
||||
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
UniqueUserAiProviderKeysPkey UniqueConstraint = "user_ai_provider_keys_pkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id);
|
||||
UniqueUserAiProviderKeysUserIDAiProviderIDKey UniqueConstraint = "user_ai_provider_keys_user_id_ai_provider_id_key" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id);
|
||||
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
|
||||
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
|
||||
|
||||
+183
-868
File diff suppressed because it is too large
Load Diff
+87
-530
@@ -890,26 +890,6 @@ func TestPostChats_ClientType(t *testing.T) {
|
||||
func TestListChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sortedChatIDs := func(ids []uuid.UUID) []uuid.UUID {
|
||||
out := append([]uuid.UUID(nil), ids...)
|
||||
slices.SortFunc(out, func(a, b uuid.UUID) int {
|
||||
return strings.Compare(a.String(), b.String())
|
||||
})
|
||||
return out
|
||||
}
|
||||
requireRootIDs := func(t *testing.T, chats []codersdk.Chat, want ...uuid.UUID) []uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
got := make([]uuid.UUID, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
require.Nil(t, chat.ParentChatID, "list should only return root chats")
|
||||
got = append(got, chat.ID)
|
||||
}
|
||||
|
||||
require.Equal(t, sortedChatIDs(want), sortedChatIDs(got))
|
||||
return got
|
||||
}
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1559,171 +1539,6 @@ func TestListChats(t *testing.T) {
|
||||
require.Equal(t, archivedWithPR.ID, chats[0].ID)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("TitleSearch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Verify that the title: filter is wired through the endpoint.
|
||||
// Exhaustive ILIKE behavior is tested in TestGetChatsFilter (Title/* subtests).
|
||||
alpha := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "alpha project",
|
||||
})
|
||||
_ = dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "beta unrelated",
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
t.Run("SingleWord", func(t *testing.T) {
|
||||
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "title:alpha"})
|
||||
require.NoError(t, err)
|
||||
requireRootIDs(t, chats, alpha.ID)
|
||||
})
|
||||
|
||||
t.Run("MultiWord", func(t *testing.T) {
|
||||
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: `title:"alpha project"`})
|
||||
require.NoError(t, err)
|
||||
requireRootIDs(t, chats, alpha.ID)
|
||||
})
|
||||
|
||||
t.Run("BareTermsRejected", func(t *testing.T) {
|
||||
_, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "bare words"})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("PRStatusFilter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Verify that pr_status filter is wired through the endpoint.
|
||||
// Exhaustive query logic is tested in TestGetChatsFilter (PRStatus/* subtests).
|
||||
createChatWithPR := func(title, prURL, prState string, prDraft bool) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: title,
|
||||
Status: database.ChatStatusCompleted,
|
||||
})
|
||||
refreshedAt := time.Now().UTC().Truncate(time.Second)
|
||||
staleAt := refreshedAt.Add(time.Hour)
|
||||
_, err := db.UpsertChatDiffStatusReference(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: prURL, Valid: true},
|
||||
GitBranch: "feature/test",
|
||||
GitRemoteOrigin: "git@github.com:coder/coder.git",
|
||||
StaleAt: staleAt,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = db.UpsertChatDiffStatus(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
database.UpsertChatDiffStatusParams{
|
||||
ChatID: chat.ID,
|
||||
Url: sql.NullString{String: prURL, Valid: true},
|
||||
PullRequestState: sql.NullString{String: prState, Valid: true},
|
||||
PullRequestDraft: prDraft,
|
||||
RefreshedAt: refreshedAt,
|
||||
StaleAt: staleAt,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
draftChat := createChatWithPR("draft pr", "https://github.com/coder/coder/pull/301", "open", true)
|
||||
_ = createChatWithPR("open pr", "https://github.com/coder/coder/pull/302", "open", false)
|
||||
|
||||
t.Run("MatchesDraft", func(t *testing.T) {
|
||||
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "pr_status:draft"})
|
||||
require.NoError(t, err)
|
||||
requireRootIDs(t, chats, draftChat.ID)
|
||||
})
|
||||
|
||||
t.Run("InvalidPRStatus", func(t *testing.T) {
|
||||
_, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "pr_status:bogus"})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("UnreadFilter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Verify that has_unread:true filter is wired through the endpoint.
|
||||
// Exhaustive query logic is tested in TestGetChatsFilter (Unread/* subtests).
|
||||
unreadChat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "unread chat",
|
||||
Status: database.ChatStatusCompleted,
|
||||
})
|
||||
_, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: unreadChat.ID,
|
||||
CreatedBy: []uuid.UUID{user.UserID},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a second chat with NO unread messages to prove filtering works.
|
||||
_ = dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "read chat",
|
||||
Status: database.ChatStatusCompleted,
|
||||
})
|
||||
|
||||
t.Run("MatchesUnread", func(t *testing.T) {
|
||||
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "has_unread:true"})
|
||||
require.NoError(t, err)
|
||||
requireRootIDs(t, chats, unreadChat.ID)
|
||||
})
|
||||
|
||||
t.Run("InvalidHasUnread", func(t *testing.T) {
|
||||
_, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "has_unread:bogus"})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChatModels(t *testing.T) {
|
||||
@@ -1802,18 +1617,12 @@ func TestListChatModels(t *testing.T) {
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
providerType := database.AiProviderTypeAnthropic
|
||||
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: string(providerType),
|
||||
CentralAPIKeyEnabled: ptr.Ref(false),
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, string(providerType), "")
|
||||
provider := createAIProviderForTest(t, client, string(providerType), "")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(providerType),
|
||||
AIProviderID: &aiProvider.ID,
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "claude-sonnet",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -1833,7 +1642,7 @@ func TestListChatModels(t *testing.T) {
|
||||
require.False(t, anthropicProvider.Available)
|
||||
require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason)
|
||||
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
_, err = client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{
|
||||
APIKey: "user-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1859,20 +1668,12 @@ func TestListChatModels(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "google",
|
||||
APIKey: "central-api-key",
|
||||
CentralAPIKeyEnabled: ptr.Ref(true),
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
AllowCentralAPIKeyFallback: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "google", "provider-api-key")
|
||||
provider := createAIProviderForTest(t, client, "google", "provider-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "google",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gemini-1.5-pro",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -1891,7 +1692,7 @@ func TestListChatModels(t *testing.T) {
|
||||
require.NotNil(t, googleProvider)
|
||||
require.True(t, googleProvider.Available)
|
||||
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
_, err = client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{
|
||||
APIKey: "user-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1919,17 +1720,12 @@ func TestListChatModels(t *testing.T) {
|
||||
client := newChatClientWithDeploymentValues(t, values)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-key")
|
||||
provider := createAIProviderForTest(t, client, "openai", "test-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -1943,7 +1739,7 @@ func TestListChatModels(t *testing.T) {
|
||||
require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model)
|
||||
|
||||
enabled := false
|
||||
_, err = client.UpdateChatProvider(ctx, chatProvider.ID, codersdk.UpdateChatProviderConfigRequest{
|
||||
_, err = client.UpdateAIProvider(ctx, provider.ID.String(), codersdk.UpdateAIProviderRequest{
|
||||
Enabled: &enabled,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -2450,6 +2246,7 @@ func TestUserAIProviderKeys(t *testing.T) {
|
||||
|
||||
func TestListChatProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("legacy chat provider API removed in favor of AI provider API")
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -2519,6 +2316,7 @@ func TestListChatProviders(t *testing.T) {
|
||||
|
||||
func TestCreateChatProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("legacy chat provider API removed in favor of AI provider API")
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -2819,6 +2617,7 @@ func TestCreateChatProvider(t *testing.T) {
|
||||
|
||||
func TestUpdateChatProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("legacy chat provider API removed in favor of AI provider API")
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -3111,282 +2910,7 @@ func TestUpdateChatProvider(t *testing.T) {
|
||||
|
||||
func TestDeleteChatProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteChatProvider(ctx, provider.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
providers, err := client.ListChatProviders(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, listed := range providers {
|
||||
require.NotEqual(t, provider.ID, listed.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SuccessWithHistoricalChats", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
providerToDelete, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "delete-api-key",
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProviderToDelete := createAIProviderForTest(t, client, providerToDelete.Provider, "delete-api-key")
|
||||
|
||||
deleteContextLimit := int64(4096)
|
||||
deleteIsDefault := true
|
||||
configToDelete, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: providerToDelete.Provider,
|
||||
AIProviderID: &aiProviderToDelete.ID,
|
||||
Model: "gpt-4o-delete-provider",
|
||||
ContextLimit: &deleteContextLimit,
|
||||
IsDefault: &deleteIsDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
keepProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "anthropic",
|
||||
APIKey: "keep-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keepAIProvider := createAIProviderForTest(t, client, keepProvider.Provider, "keep-api-key")
|
||||
|
||||
keepContextLimit := int64(8192)
|
||||
keepConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: keepProvider.Provider,
|
||||
AIProviderID: &keepAIProvider.ID,
|
||||
Model: "claude-keep-provider",
|
||||
ContextLimit: &keepContextLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
ModelConfigID: ptr.Ref(configToDelete.ID),
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "provider delete history " + t.Name(),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, configToDelete.ID, chat.LastModelConfigID)
|
||||
|
||||
insertAssistantCostMessage(t, db, chat.ID, configToDelete.ID, 500)
|
||||
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, providerToDelete.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
APIKey: "user-delete-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
userKeys, err := db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, userKeys, 1)
|
||||
require.Equal(t, providerToDelete.ID, userKeys[0].ChatProviderID)
|
||||
|
||||
err = client.DeleteChatProvider(ctx, providerToDelete.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), providerToDelete.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
providers, err := client.ListChatProviders(ctx)
|
||||
require.NoError(t, err)
|
||||
foundKeepProvider := false
|
||||
for _, listed := range providers {
|
||||
require.NotEqual(t, providerToDelete.ID, listed.ID)
|
||||
if listed.ID == keepProvider.ID {
|
||||
foundKeepProvider = true
|
||||
}
|
||||
}
|
||||
require.True(t, foundKeepProvider)
|
||||
|
||||
configs, err := client.ListChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
foundDeletedConfig := false
|
||||
foundKeepConfig := false
|
||||
for _, config := range configs {
|
||||
if config.ID == configToDelete.ID {
|
||||
foundDeletedConfig = true
|
||||
}
|
||||
if config.ID == keepConfig.ID {
|
||||
foundKeepConfig = true
|
||||
require.True(t, config.IsDefault)
|
||||
}
|
||||
}
|
||||
require.False(t, foundDeletedConfig)
|
||||
require.True(t, foundKeepConfig)
|
||||
|
||||
defaultConfig, err := db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keepConfig.ID, defaultConfig.ID)
|
||||
|
||||
_, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), configToDelete.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
gotChat, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chat.ID, gotChat.ID)
|
||||
require.Equal(t, configToDelete.ID, gotChat.LastModelConfigID)
|
||||
|
||||
messages, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
foundHistoricalMessage := false
|
||||
for _, message := range messages.Messages {
|
||||
if message.ModelConfigID != nil && *message.ModelConfigID == configToDelete.ID {
|
||||
foundHistoricalMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundHistoricalMessage)
|
||||
|
||||
userKeys, err = db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, userKeys)
|
||||
})
|
||||
|
||||
t.Run("SuccessWithHistoricalChatsAndNoReplacementConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "only-provider-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, provider.Provider, "only-provider-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
config, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-only-provider",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
ModelConfigID: ptr.Ref(config.ID),
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "only provider delete history " + t.Name(),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, config.ID, chat.LastModelConfigID)
|
||||
|
||||
insertAssistantCostMessage(t, db, chat.ID, config.ID, 250)
|
||||
|
||||
err = client.DeleteChatProvider(ctx, provider.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
providers, err := client.ListChatProviders(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, listed := range providers {
|
||||
require.NotEqual(t, provider.ID, listed.ID)
|
||||
}
|
||||
|
||||
_, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), provider.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
_, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), config.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
_, err = db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx))
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
configs, err := client.ListChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, configs)
|
||||
|
||||
gotChat, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, config.ID, gotChat.LastModelConfigID)
|
||||
|
||||
messages, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
foundHistoricalMessage := false
|
||||
for _, message := range messages.Messages {
|
||||
if message.ModelConfigID != nil && *message.ModelConfigID == config.ID {
|
||||
foundHistoricalMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundHistoricalMessage)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
err := client.DeleteChatProvider(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("InvalidProviderID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodDelete,
|
||||
"/api/experimental/chats/providers/not-a-uuid",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat provider ID.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("ForbiddenForOrganizationMember", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = memberClient.DeleteChatProvider(ctx, provider.ID)
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
t.Skip("legacy chat provider API removed in favor of AI provider API")
|
||||
}
|
||||
|
||||
func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) {
|
||||
@@ -3414,6 +2938,7 @@ func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) {
|
||||
|
||||
func TestUserChatProviderConfigs(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("legacy chat provider API removed in favor of AI provider API")
|
||||
|
||||
requireUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) codersdk.UserChatProviderConfig {
|
||||
t.Helper()
|
||||
@@ -3809,6 +3334,7 @@ func TestUserChatProviderConfigs(t *testing.T) {
|
||||
|
||||
func TestUpsertUserChatProviderKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("legacy chat provider API removed in favor of AI provider API")
|
||||
|
||||
t.Run("RejectsTooLargeAPIKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -4263,6 +3789,26 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
requireChatModelPricing(t, configs[0].ModelConfig, pricing)
|
||||
})
|
||||
|
||||
t.Run("UnchangedProviderWithoutAIProviderID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
Provider: modelConfig.Provider,
|
||||
Model: "gpt-4o-mini-updated",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, modelConfig.ID, updated.ID)
|
||||
require.Equal(t, modelConfig.Provider, updated.Provider)
|
||||
require.NotNil(t, updated.AIProviderID)
|
||||
require.Equal(t, *modelConfig.AIProviderID, *updated.AIProviderID)
|
||||
require.Equal(t, "gpt-4o-mini-updated", updated.Model)
|
||||
})
|
||||
|
||||
t.Run("DisablePreservesRecordAndHidesItFromNonAdmins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -4482,11 +4028,12 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
missingProviderID := uuid.New()
|
||||
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
AIProviderID: &missingProviderID,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "Chat provider is not configured.", sdkErr.Message)
|
||||
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("NotFoundWhenTargetRowDisappearsInTx", func(t *testing.T) {
|
||||
@@ -10953,32 +10500,27 @@ func enableUserChatProviderKey(
|
||||
adminClient *codersdk.ExperimentalClient,
|
||||
userClient *codersdk.ExperimentalClient,
|
||||
providerName string,
|
||||
) codersdk.ChatProviderConfig {
|
||||
) codersdk.AIProvider {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
providers, err := adminClient.ListChatProviders(ctx)
|
||||
providers, err := adminClient.AIProviders(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var provider codersdk.ChatProviderConfig
|
||||
var provider codersdk.AIProvider
|
||||
for _, candidate := range providers {
|
||||
if candidate.Provider == providerName && candidate.Source == codersdk.ChatProviderConfigSourceDatabase {
|
||||
if candidate.Type == codersdk.AIProviderType(providerName) {
|
||||
provider = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEqual(t, uuid.Nil, provider.ID)
|
||||
|
||||
updated, err := adminClient.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = userClient.UpsertUserChatProviderKey(ctx, updated.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
_, err = userClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{
|
||||
APIKey: "test-user-api-key-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return updated
|
||||
return provider
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
@@ -11792,13 +11334,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) {
|
||||
|
||||
defaultModelConfig := createChatModelConfig(t, adminClient)
|
||||
provider := enableUserChatProviderKey(t, adminClient, memberClient, defaultModelConfig.Provider)
|
||||
modelConfig := createAdditionalChatModelConfig(
|
||||
t,
|
||||
adminClient,
|
||||
defaultModelConfig.Provider,
|
||||
"gpt-4o-personal-"+uuid.NewString(),
|
||||
)
|
||||
err := adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{
|
||||
modelProvider := createAIProviderForTest(t, adminClient, "anthropic", "")
|
||||
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", modelProvider.ID, codersdk.CreateUserAIProviderKeyRequest{
|
||||
APIKey: "test-user-api-key-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
contextLimit := int64(4096)
|
||||
modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
AIProviderID: &modelProvider.ID,
|
||||
Model: "claude-personal-" + uuid.NewString(),
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{
|
||||
ModelConfigID: modelConfig.ID.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -11813,19 +11362,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) {
|
||||
defaultModelConfig.Provider,
|
||||
"gpt-4o-personal-disabled-"+uuid.NewString(),
|
||||
)
|
||||
disabledProvider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "anthropic",
|
||||
Enabled: ptr.Ref(false),
|
||||
CentralAPIKeyEnabled: ptr.Ref(false),
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
disabledProvider := createAIProviderForTest(t, adminClient, "google", "test-api-key")
|
||||
contextLimit = int64(4096)
|
||||
disabledProviderModelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "google",
|
||||
AIProviderID: &disabledProvider.ID,
|
||||
Model: "gemini-personal-disabled-provider-" + uuid.NewString(),
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
enabled := false
|
||||
disabledProvider, err = adminClient.UpdateAIProvider(ctx, disabledProvider.ID.String(), codersdk.UpdateAIProviderRequest{
|
||||
Enabled: &enabled,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
disabledProviderModelConfig := createAdditionalChatModelConfig(
|
||||
t,
|
||||
adminClient,
|
||||
"anthropic",
|
||||
"claude-personal-disabled-provider-"+uuid.NewString(),
|
||||
)
|
||||
require.NotEqual(t, uuid.Nil, provider.ID)
|
||||
require.NotEqual(t, uuid.Nil, disabledProvider.ID)
|
||||
|
||||
@@ -12153,12 +11703,19 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) {
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
defaultModel := createChatModelConfig(t, adminClient)
|
||||
_ = enableUserChatProviderKey(t, adminClient, adminClient, defaultModel.Provider)
|
||||
overrideModel := createAdditionalChatModelConfig(
|
||||
t,
|
||||
adminClient,
|
||||
defaultModel.Provider,
|
||||
"gpt-4o-root-personal-"+uuid.NewString(),
|
||||
)
|
||||
overrideProvider := createAIProviderForTest(t, adminClient, "anthropic", "")
|
||||
_, err := adminClient.UpsertUserAIProviderKey(ctx, "me", overrideProvider.ID, codersdk.CreateUserAIProviderKeyRequest{
|
||||
APIKey: "test-user-api-key-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
contextLimit := int64(4096)
|
||||
overrideModel, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
AIProviderID: &overrideProvider.ID,
|
||||
Model: "claude-root-personal-" + uuid.NewString(),
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
disabledModel := createDisabledChatModelConfig(
|
||||
t,
|
||||
adminClient,
|
||||
@@ -12203,7 +11760,7 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{
|
||||
err = adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{
|
||||
AllowUsers: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2,6 +2,7 @@ package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -96,15 +97,19 @@ func insertAgentChatTestModelConfig(
|
||||
|
||||
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
|
||||
|
||||
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
provider := dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "test-openai",
|
||||
DisplayName: sql.NullString{String: "OpenAI", Valid: true},
|
||||
})
|
||||
dbgen.AIProviderKey(t, db, database.AIProviderKey{
|
||||
ProviderID: provider.ID,
|
||||
APIKey: "test-api-key",
|
||||
CreatedBy: createdBy,
|
||||
})
|
||||
|
||||
return dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
CreatedBy: createdBy,
|
||||
UpdatedBy: createdBy,
|
||||
IsDefault: true,
|
||||
|
||||
@@ -24,8 +24,8 @@ import (
|
||||
|
||||
// advisorOverrideStubStore stubs only the database methods that
|
||||
// resolveAdvisorModelOverride exercises. The prod code calls
|
||||
// GetEnabledChatModelConfigByID so the query joins chat_providers and
|
||||
// filters both enabled flags atomically; tests simulate that by returning
|
||||
// GetEnabledChatModelConfigByID so the query joins ai_providers and
|
||||
// filters both enabled flags atomically. Tests simulate that by returning
|
||||
// configs the stub treats as enabled.
|
||||
type advisorOverrideStubStore struct {
|
||||
database.Store
|
||||
|
||||
+66
-48
@@ -350,13 +350,8 @@ func (p *Server) resolveAdvisorModelOverride(
|
||||
return fallbackModel, fallbackCallConfig
|
||||
}
|
||||
|
||||
// GetEnabledChatModelConfigByID checks the model config and referenced
|
||||
// provider enabled state, so it returns sql.ErrNoRows the moment an
|
||||
// admin disables either one. Using the cached ModelConfigByID here
|
||||
// would keep resolving an override whose provider was just disabled,
|
||||
// and an available fallback key would let ModelFromConfig succeed,
|
||||
// silently routing advisor prompts to a provider the admin expects to
|
||||
// be off.
|
||||
// Re-read the override instead of using the cache so disabled models
|
||||
// or providers stop routing advisor prompts immediately.
|
||||
overrideConfig, err := p.db.GetEnabledChatModelConfigByID(
|
||||
ctx,
|
||||
advisorCfg.ModelConfigID,
|
||||
@@ -7102,9 +7097,8 @@ func (p *Server) runChat(
|
||||
// Fire title generation asynchronously so it doesn't block the
|
||||
// chat response. It uses a detached context so it can finish
|
||||
// even after the chat processing context is canceled.
|
||||
// Snapshot model, provider keys, logger, and ctx before launch; all four get
|
||||
// reassigned below (model = cuModel, providerKeys = computerUseProviderKeys,
|
||||
// logger = logger.With(...), ctx = runCtx) and the goroutine captures by reference.
|
||||
// Snapshot values captured by the goroutine because model, providerKeys,
|
||||
// logger, and ctx are reassigned below.
|
||||
titleModel := model
|
||||
titleProviderKeys := providerKeys
|
||||
titleLogger := logger
|
||||
@@ -8502,13 +8496,17 @@ func (p *Server) resolveChatModel(
|
||||
}
|
||||
|
||||
func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) {
|
||||
if !provider.Enabled {
|
||||
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
|
||||
}
|
||||
keys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err)
|
||||
}
|
||||
return p.aiProviderConfigFromKeys(provider, keys)
|
||||
}
|
||||
|
||||
func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) {
|
||||
if !provider.Enabled {
|
||||
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
|
||||
}
|
||||
apiKey := ""
|
||||
// GetAIProviderKeysByProviderID orders keys oldest first. chatd consumes
|
||||
// one provider-scoped key because runtime provider config has one API key slot.
|
||||
@@ -8529,6 +8527,48 @@ func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvi
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
providerIDs := make([]uuid.UUID, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
providerIDs = append(providerIDs, provider.ID)
|
||||
}
|
||||
keys, err := p.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get AI provider keys: %w", err)
|
||||
}
|
||||
keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers))
|
||||
for _, key := range keys {
|
||||
keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key)
|
||||
}
|
||||
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
configuredProvider, err := p.aiProviderConfigFromKeys(provider, keysByProviderID[provider.ID])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configuredProviders = append(configuredProviders, configuredProvider)
|
||||
}
|
||||
return configuredProviders, nil
|
||||
}
|
||||
|
||||
func ensureUniqueConfiguredProviderTypes(providers []chatprovider.ConfiguredProvider) error {
|
||||
seen := make(map[string]uuid.UUID, len(providers))
|
||||
for _, provider := range providers {
|
||||
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if existingProviderID, ok := seen[normalizedProvider]; ok && existingProviderID != provider.ProviderID {
|
||||
return xerrors.Errorf("multiple enabled AI providers use provider type %q; select an AI provider by ID", normalizedProvider)
|
||||
}
|
||||
seen[normalizedProvider] = provider.ProviderID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveUserProviderAPIKeysForProvider(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
@@ -8559,12 +8599,6 @@ func (p *Server) resolveUserProviderAPIKeysForProvider(
|
||||
[]chatprovider.ConfiguredProvider{configuredProvider},
|
||||
userKeys,
|
||||
)
|
||||
enabledProviders := map[string]struct{}{}
|
||||
normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider)
|
||||
if normalizedProvider != "" {
|
||||
enabledProviders[normalizedProvider] = struct{}{}
|
||||
}
|
||||
chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
@@ -8598,9 +8632,6 @@ func (p *Server) resolveUserProviderAPIKeys(
|
||||
ownerID uuid.UUID,
|
||||
selectedAIProviderID uuid.UUID,
|
||||
) (chatprovider.ProviderAPIKeys, error) {
|
||||
var configuredProviders []chatprovider.ConfiguredProvider
|
||||
userKeys := []chatprovider.UserProviderKey{}
|
||||
|
||||
if selectedAIProviderID != uuid.Nil {
|
||||
provider, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID)
|
||||
if err != nil {
|
||||
@@ -8608,48 +8639,35 @@ func (p *Server) resolveUserProviderAPIKeys(
|
||||
}
|
||||
return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
|
||||
}
|
||||
|
||||
providers, err := p.configCache.EnabledProviders(ctx)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"get enabled chat providers: %w",
|
||||
"get enabled AI providers: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
configuredProviders = make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(providers),
|
||||
)
|
||||
for _, provider := range providers {
|
||||
configuredProviders = append(
|
||||
configuredProviders, chatprovider.ConfiguredProvider{
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
},
|
||||
)
|
||||
configuredProviders, err := p.aiProviderConfigs(ctx, providers)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, err
|
||||
}
|
||||
allowAnyUserAPIKey := false
|
||||
for _, provider := range configuredProviders {
|
||||
if provider.AllowUserAPIKey {
|
||||
allowAnyUserAPIKey = true
|
||||
break
|
||||
if err := ensureUniqueConfiguredProviderTypes(configuredProviders); err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, err
|
||||
}
|
||||
}
|
||||
if allowAnyUserAPIKey {
|
||||
userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID)
|
||||
|
||||
userKeys := []chatprovider.UserProviderKey{}
|
||||
if p.allowBYOK {
|
||||
userKeyRows, err := p.db.GetUserAIProviderKeysByUserID(ctx, ownerID)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"get user chat provider keys: %w",
|
||||
"get user AI provider keys: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
||||
for _, userKey := range userKeyRows {
|
||||
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
||||
ChatProviderID: userKey.ChatProviderID,
|
||||
ChatProviderID: userKey.AIProviderID,
|
||||
APIKey: userKey.APIKey,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -797,12 +797,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
APIKey: "test-key",
|
||||
providerID := uuid.New()
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
gomock.Any(),
|
||||
@@ -961,12 +963,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
APIKey: "test-key",
|
||||
providerID := uuid.New()
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
gomock.Any(),
|
||||
@@ -1110,11 +1114,13 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "anthropic",
|
||||
CentralApiKeyEnabled: true,
|
||||
AllowCentralApiKeyFallback: true,
|
||||
providerID := uuid.New()
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeAnthropic,
|
||||
Enabled: true,
|
||||
}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
||||
require.NoError(t, err)
|
||||
@@ -1185,10 +1191,13 @@ func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKe
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
providerID := uuid.New()
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
||||
require.NoError(t, err)
|
||||
@@ -3922,7 +3931,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
|
||||
database.ChatModelConfig{}, xerrors.New("no model configured"),
|
||||
).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
||||
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
||||
@@ -5696,7 +5705,7 @@ func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) {
|
||||
return database.ChatModelConfig{}, chatloop.ErrInterrupted
|
||||
},
|
||||
).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
||||
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
||||
|
||||
@@ -6120,21 +6120,24 @@ func setOpenAIProviderBaseURL(
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
provider, err := db.GetChatProviderByProvider(ctx, "openai")
|
||||
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
for _, provider := range providers {
|
||||
if provider.Type != database.AiProviderTypeOpenai {
|
||||
continue
|
||||
}
|
||||
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
BaseUrl: baseURL,
|
||||
Settings: provider.Settings,
|
||||
SettingsKeyID: provider.SettingsKeyID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
require.Fail(t, "openai provider not found")
|
||||
}
|
||||
|
||||
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
||||
@@ -7193,9 +7196,10 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) {
|
||||
true,
|
||||
false,
|
||||
)
|
||||
_, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
_, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
ChatProviderID: provider.ID,
|
||||
AIProviderID: provider.ID,
|
||||
APIKey: userAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
)
|
||||
|
||||
type cachedProviders struct {
|
||||
providers []database.ChatProvider
|
||||
providers []database.AIProvider
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ type chatConfigCache struct {
|
||||
// Providers (singleton).
|
||||
providers *cachedProviders
|
||||
providerGeneration uint64
|
||||
providerFetches singleflight.Group[string, []database.ChatProvider]
|
||||
providerFetches singleflight.Group[string, []database.AIProvider]
|
||||
|
||||
// Model configs (keyed by ID).
|
||||
modelTopologyEpoch uint64
|
||||
@@ -131,7 +131,7 @@ func singleflightDoChan[K comparable, V any](
|
||||
}
|
||||
}
|
||||
|
||||
func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.AIProvider, error) {
|
||||
if providers, ok := c.cachedProviders(); ok {
|
||||
return providers, nil
|
||||
}
|
||||
@@ -141,12 +141,12 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat
|
||||
ctx,
|
||||
&c.providerFetches,
|
||||
fmt.Sprintf("%d:providers", generation),
|
||||
func() ([]database.ChatProvider, error) {
|
||||
func() ([]database.AIProvider, error) {
|
||||
if cached, ok := c.cachedProviders(); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
fetched, err := c.db.GetEnabledChatProviders(c.ctx)
|
||||
fetched, err := c.db.GetAIProviders(c.ctx, database.GetAIProvidersParams{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -161,7 +161,7 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat
|
||||
return slices.Clone(providers), nil
|
||||
}
|
||||
|
||||
func (c *chatConfigCache) cachedProviders() ([]database.ChatProvider, bool) {
|
||||
func (c *chatConfigCache) cachedProviders() ([]database.AIProvider, bool) {
|
||||
c.mu.RLock()
|
||||
entry := c.providers
|
||||
c.mu.RUnlock()
|
||||
@@ -188,7 +188,7 @@ func (c *chatConfigCache) providersGeneration() uint64 {
|
||||
return generation
|
||||
}
|
||||
|
||||
func (c *chatConfigCache) storeProviders(generation uint64, providers []database.ChatProvider) {
|
||||
func (c *chatConfigCache) storeProviders(generation uint64, providers []database.AIProvider) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
type stubChatConfigStore struct {
|
||||
database.Store
|
||||
|
||||
getEnabledChatProviders func(context.Context) ([]database.ChatProvider, error)
|
||||
getAIProviders func(context.Context) ([]database.AIProvider, error)
|
||||
getChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
|
||||
getDefaultChatModelConfig func(context.Context) (database.ChatModelConfig, error)
|
||||
getUserChatCustomPrompt func(context.Context, uuid.UUID) (string, error)
|
||||
@@ -35,12 +35,12 @@ type stubChatConfigStore struct {
|
||||
advisorConfigCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *stubChatConfigStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
func (s *stubChatConfigStore) GetAIProviders(ctx context.Context, _ database.GetAIProvidersParams) ([]database.AIProvider, error) {
|
||||
s.enabledProvidersCalls.Add(1)
|
||||
if s.getEnabledChatProviders == nil {
|
||||
panic("unexpected GetEnabledChatProviders call")
|
||||
if s.getAIProviders == nil {
|
||||
panic("unexpected GetAIProviders call")
|
||||
}
|
||||
return s.getEnabledChatProviders(ctx)
|
||||
return s.getAIProviders(ctx)
|
||||
}
|
||||
|
||||
func (s *stubChatConfigStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) {
|
||||
@@ -80,9 +80,9 @@ func TestConfigCache_EnabledProviders_CacheHit(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
providers := []database.ChatProvider{testChatProvider("provider-a")}
|
||||
providers := []database.AIProvider{testAIProvider("provider-a")}
|
||||
store := &stubChatConfigStore{
|
||||
getEnabledChatProviders: func(context.Context) ([]database.ChatProvider, error) {
|
||||
getAIProviders: func(context.Context) ([]database.AIProvider, error) {
|
||||
return providers, nil
|
||||
},
|
||||
}
|
||||
@@ -104,9 +104,9 @@ func TestConfigCache_EnabledProviders_TTLExpiry(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
store := &stubChatConfigStore{}
|
||||
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
|
||||
call := store.enabledProvidersCalls.Load()
|
||||
return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil
|
||||
return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil
|
||||
}
|
||||
cache := newChatConfigCache(ctx, store, clock)
|
||||
|
||||
@@ -126,9 +126,9 @@ func TestConfigCache_EnabledProviders_Invalidation(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
store := &stubChatConfigStore{}
|
||||
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
|
||||
call := store.enabledProvidersCalls.Load()
|
||||
return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil
|
||||
return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil
|
||||
}
|
||||
cache := newChatConfigCache(ctx, store, clock)
|
||||
|
||||
@@ -398,12 +398,12 @@ func TestConfigCache_Singleflight(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
clock := quartz.NewMock(t)
|
||||
providers := []database.ChatProvider{testChatProvider("provider-a")}
|
||||
providers := []database.AIProvider{testAIProvider("provider-a")}
|
||||
fetchStarted := make(chan struct{})
|
||||
releaseFetch := make(chan struct{})
|
||||
var startedOnce sync.Once
|
||||
store := &stubChatConfigStore{}
|
||||
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
|
||||
startedOnce.Do(func() { close(fetchStarted) })
|
||||
<-releaseFetch
|
||||
return providers, nil
|
||||
@@ -411,7 +411,7 @@ func TestConfigCache_Singleflight(t *testing.T) {
|
||||
cache := newChatConfigCache(ctx, store, clock)
|
||||
|
||||
const callers = 8
|
||||
results := make([][]database.ChatProvider, callers)
|
||||
results := make([][]database.AIProvider, callers)
|
||||
errs := make([]error, callers)
|
||||
var wg sync.WaitGroup
|
||||
start := make(chan struct{})
|
||||
@@ -441,13 +441,13 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
clock := quartz.NewMock(t)
|
||||
firstProviders := []database.ChatProvider{testChatProvider("provider-a")}
|
||||
secondProviders := []database.ChatProvider{testChatProvider("provider-b")}
|
||||
firstProviders := []database.AIProvider{testAIProvider("provider-a")}
|
||||
secondProviders := []database.AIProvider{testAIProvider("provider-b")}
|
||||
fetchStarted := make(chan struct{})
|
||||
releaseFetch := make(chan struct{})
|
||||
var startedOnce sync.Once
|
||||
store := &stubChatConfigStore{}
|
||||
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
|
||||
call := store.enabledProvidersCalls.Load()
|
||||
if call == 1 {
|
||||
startedOnce.Do(func() { close(fetchStarted) })
|
||||
@@ -458,7 +458,7 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) {
|
||||
}
|
||||
cache := newChatConfigCache(ctx, store, clock)
|
||||
|
||||
resultCh := make(chan []database.ChatProvider, 1)
|
||||
resultCh := make(chan []database.AIProvider, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
providers, err := cache.EnabledProviders(ctx)
|
||||
@@ -494,14 +494,14 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
clock := quartz.NewMock(t)
|
||||
staleProviders := []database.ChatProvider{testChatProvider("provider-stale")}
|
||||
freshProviders := []database.ChatProvider{testChatProvider("provider-fresh")}
|
||||
staleProviders := []database.AIProvider{testAIProvider("provider-stale")}
|
||||
freshProviders := []database.AIProvider{testAIProvider("provider-fresh")}
|
||||
firstStarted := make(chan struct{})
|
||||
secondStarted := make(chan struct{})
|
||||
releaseFirst := make(chan struct{})
|
||||
releaseSecond := make(chan struct{})
|
||||
store := &stubChatConfigStore{}
|
||||
store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(context.Context) ([]database.AIProvider, error) {
|
||||
switch call := store.enabledProvidersCalls.Load(); call {
|
||||
case 1:
|
||||
close(firstStarted)
|
||||
@@ -518,7 +518,7 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing
|
||||
cache := newChatConfigCache(ctx, store, clock)
|
||||
|
||||
type result struct {
|
||||
providers []database.ChatProvider
|
||||
providers []database.AIProvider
|
||||
err error
|
||||
}
|
||||
|
||||
@@ -670,11 +670,12 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig(t *testi
|
||||
require.Equal(t, int32(2), store.modelConfigByIDCalls.Load())
|
||||
}
|
||||
|
||||
func testChatProvider(name string) database.ChatProvider {
|
||||
return database.ChatProvider{
|
||||
func testAIProvider(name string) database.AIProvider {
|
||||
return database.AIProvider{
|
||||
ID: uuid.New(),
|
||||
Provider: name,
|
||||
DisplayName: name,
|
||||
Type: database.AIProviderType(name),
|
||||
Name: name,
|
||||
DisplayName: sql.NullString{String: name, Valid: true},
|
||||
Enabled: true,
|
||||
CreatedAt: time.Unix(0, 0).UTC(),
|
||||
UpdatedAt: time.Unix(0, 0).UTC(),
|
||||
@@ -737,19 +738,19 @@ func TestConfigCache_CallerCancellation(t *testing.T) {
|
||||
name: "EnabledProviders",
|
||||
setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) {
|
||||
var once sync.Once
|
||||
store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) {
|
||||
once.Do(func() { close(started) })
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-release:
|
||||
return []database.ChatProvider{testChatProvider("p")}, nil
|
||||
return []database.AIProvider{testAIProvider("p")}, nil
|
||||
}
|
||||
}
|
||||
},
|
||||
setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) {
|
||||
var once sync.Once
|
||||
store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) {
|
||||
once.Do(func() { close(started) })
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
|
||||
@@ -154,12 +154,12 @@ func validateModelConfigAndResolveProvider(
|
||||
}
|
||||
|
||||
func enabledProviderContainsName(
|
||||
providers []database.ChatProvider,
|
||||
providers []database.AIProvider,
|
||||
providerName string,
|
||||
) bool {
|
||||
normalizedProviderName := chatprovider.NormalizeProvider(providerName)
|
||||
for _, provider := range providers {
|
||||
if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName {
|
||||
if chatprovider.NormalizeProvider(string(provider.Type)) == normalizedProviderName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -183,12 +184,14 @@ func seedInternalChatDeps(
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
provider := dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
})
|
||||
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
IsDefault: true,
|
||||
})
|
||||
|
||||
@@ -309,24 +312,38 @@ func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) {
|
||||
require.True(t, keys.HasProvider("bedrock"))
|
||||
require.Empty(t, keys.APIKey("bedrock"))
|
||||
})
|
||||
|
||||
t.Run("RejectsAmbiguousProviderTypeWithoutSelectedProvider", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
ctx := chatdTestContext(t)
|
||||
user, _, _ := seedInternalChatDeps(t, db)
|
||||
insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "first-provider-api-key", true)
|
||||
insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "second-provider-api-key", true)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, uuid.Nil)
|
||||
require.ErrorContains(t, err, "multiple enabled AI providers use provider type")
|
||||
require.Equal(t, chatprovider.ProviderAPIKeys{}, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveChatModel_AIProviderDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
user, org, _ := seedInternalChatDeps(t, db)
|
||||
provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", false)
|
||||
modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
AIProviderID: uuid.NullUUID{
|
||||
UUID: provider.ID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
_, err := sqlDB.ExecContext(ctx, "UPDATE chat_model_configs SET ai_provider_id = $1 WHERE id = $2", provider.ID, modelConfig.ID)
|
||||
require.NoError(t, err)
|
||||
loadedModelConfig, err := db.GetChatModelConfigByID(ctx, modelConfig.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, loadedModelConfig.AIProviderID.Valid)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
@@ -408,19 +425,20 @@ func insertInternalChatProvider(
|
||||
centralAPIKeyEnabled bool,
|
||||
allowUserAPIKey bool,
|
||||
allowCentralAPIKeyFallback bool,
|
||||
) database.ChatProvider {
|
||||
) database.AIProvider {
|
||||
t.Helper()
|
||||
|
||||
providerConfig := dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
||||
}, func(p *database.InsertChatProviderParams) {
|
||||
p.APIKey = apiKey
|
||||
p.CentralApiKeyEnabled = centralAPIKeyEnabled
|
||||
p.AllowUserApiKey = allowUserAPIKey
|
||||
p.AllowCentralApiKeyFallback = allowCentralAPIKeyFallback
|
||||
providerConfig := dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AIProviderType(provider),
|
||||
Name: "test-" + uuid.NewString(),
|
||||
DisplayName: sql.NullString{String: provider, Valid: true},
|
||||
})
|
||||
if apiKey != "" {
|
||||
dbgen.AIProviderKey(t, db, database.AIProviderKey{
|
||||
ProviderID: providerConfig.ID,
|
||||
APIKey: apiKey,
|
||||
})
|
||||
}
|
||||
|
||||
return providerConfig
|
||||
}
|
||||
|
||||
@@ -348,7 +348,8 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{uuid.Nil}).Return(nil, nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
@@ -498,7 +499,8 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T)
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, _, err := server.resolveManualTitleModel(
|
||||
@@ -524,7 +526,8 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, _, err := server.resolveManualTitleModel(
|
||||
|
||||
@@ -33,16 +33,15 @@ func TestUpdateLastTurnSummaryRejectsStaleWrites(t *testing.T) {
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
provider := dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
@@ -102,16 +101,15 @@ func TestPendingChatPersistsSummaryButSkipsWebPush(t *testing.T) {
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
provider := dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
|
||||
@@ -944,7 +944,7 @@ func TestWorker(t *testing.T) {
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// 3. Set up FK chain: chat_providers -> chat_model_configs -> chats.
|
||||
// 3. Set up FK chain: ai_providers -> chat_model_configs -> chats.
|
||||
_ = dbgen.ChatProvider(t, db, database.ChatProvider{})
|
||||
|
||||
modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
|
||||
@@ -122,11 +122,17 @@ func seedChatDependencies(
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
provider := dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "test-" + uuid.NewString(),
|
||||
BaseUrl: safetyNet.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
})
|
||||
dbgen.AIProviderKey(t, db, database.AIProviderKey{
|
||||
ProviderID: provider.ID,
|
||||
})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
IsDefault: true,
|
||||
@@ -186,21 +192,24 @@ func setOpenAIProviderBaseURL(
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
provider, err := db.GetChatProviderByProvider(ctx, "openai")
|
||||
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
for _, provider := range providers {
|
||||
if provider.Type != database.AiProviderTypeOpenai {
|
||||
continue
|
||||
}
|
||||
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: true,
|
||||
AllowUserApiKey: false,
|
||||
AllowCentralApiKeyFallback: false,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
BaseUrl: baseURL,
|
||||
Settings: provider.Settings,
|
||||
SettingsKeyID: provider.SettingsKeyID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
require.Fail(t, "openai provider not found")
|
||||
}
|
||||
|
||||
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
|
||||
|
||||
@@ -74,29 +74,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
|
||||
}
|
||||
for _, userProviderKey := range userProviderKeys {
|
||||
if strings.TrimSpace(userProviderKey.APIKey) == "" {
|
||||
continue
|
||||
}
|
||||
if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
|
||||
UserID: userProviderKey.UserID,
|
||||
ChatProviderID: userProviderKey.ChatProviderID,
|
||||
APIKey: userProviderKey.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
|
||||
@@ -134,35 +111,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
providers, err := cryptDB.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat providers: %w", err)
|
||||
}
|
||||
log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers)))
|
||||
for idx, provider := range providers {
|
||||
if strings.TrimSpace(provider.APIKey) == "" {
|
||||
continue
|
||||
}
|
||||
if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get ai providers: %w", err)
|
||||
@@ -313,26 +261,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
|
||||
}
|
||||
for _, userProviderKey := range userProviderKeys {
|
||||
if !userProviderKey.ApiKeyKeyID.Valid {
|
||||
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
|
||||
UserID: userProviderKey.UserID,
|
||||
ChatProviderID: userProviderKey.ChatProviderID,
|
||||
APIKey: userProviderKey.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
|
||||
}
|
||||
|
||||
userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user secrets for user %s: %w", uid, err)
|
||||
@@ -370,31 +298,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
providers, err := cryptDB.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat providers: %w", err)
|
||||
}
|
||||
log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers)))
|
||||
for idx, provider := range providers {
|
||||
if !provider.ApiKeyKeyID.Valid {
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get ai providers: %w", err)
|
||||
@@ -475,16 +378,10 @@ DELETE FROM user_links
|
||||
DELETE FROM external_auth_links
|
||||
WHERE oauth_access_token_key_id IS NOT NULL
|
||||
OR oauth_refresh_token_key_id IS NOT NULL;
|
||||
DELETE FROM user_chat_provider_keys
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
DELETE FROM user_ai_provider_keys
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
DELETE FROM user_secrets
|
||||
WHERE value_key_id IS NOT NULL;
|
||||
UPDATE chat_providers
|
||||
SET api_key = '',
|
||||
api_key_key_id = NULL
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
UPDATE ai_providers
|
||||
SET settings = NULL,
|
||||
settings_key_id = NULL
|
||||
@@ -502,9 +399,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error {
|
||||
store := database.New(sqlDB)
|
||||
_, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err)
|
||||
return xerrors.Errorf("delete encrypted tokens and AI provider keys: %w", err)
|
||||
}
|
||||
log.Info(ctx, "deleted encrypted user tokens and chat provider API keys")
|
||||
log.Info(ctx, "deleted encrypted user tokens and AI provider API keys")
|
||||
|
||||
log.Info(ctx, "revoking all active keys")
|
||||
keys, err := store.GetDBCryptKeys(ctx)
|
||||
|
||||
+13
-137
@@ -521,6 +521,19 @@ func (db *dbCrypt) GetAIProviderKeysByProviderID(ctx context.Context, providerID
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
keys, err := db.Store.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range keys {
|
||||
if err := db.decryptAIProviderKey(&keys[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertAIProviderKey(ctx context.Context, params database.InsertAIProviderKeyParams) (database.AIProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
@@ -576,92 +589,6 @@ func (db *dbCrypt) UpdateEncryptedAIProviderKey(ctx context.Context, params data
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
provider, err := db.Store.GetChatProviderByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
|
||||
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
providers, err := db.Store.GetChatProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range providers {
|
||||
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
||||
providers, err := db.Store.GetEnabledChatProviders(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range providers {
|
||||
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
|
||||
provider, err := db.Store.InsertChatProvider(ctx, params)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
|
||||
provider, err := db.Store.UpdateChatProvider(ctx, params)
|
||||
if err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error {
|
||||
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
|
||||
}
|
||||
@@ -754,57 +681,6 @@ func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
|
||||
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range keys {
|
||||
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
|
||||
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
|
||||
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// decryptMCPServerConfig decrypts all encrypted fields on a
|
||||
// single MCPServerConfig in place.
|
||||
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
|
||||
|
||||
@@ -1281,6 +1281,18 @@ func TestAIProviderKeys(t *testing.T) {
|
||||
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
|
||||
})
|
||||
|
||||
t.Run("GetAIProviderKeysByProviderIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
keys, err := crypt.GetAIProviderKeysByProviderIDs(ctx, []uuid.UUID{provider.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireAIProviderKeyDecrypted(t, keys[0], ciphers, apiKey)
|
||||
requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey)
|
||||
})
|
||||
|
||||
t.Run("DeleteAIProviderKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
@@ -1558,113 +1570,6 @@ func TestMCPServerUserTokens(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserChatProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
//nolint:gosec // test credentials
|
||||
initialAPIKey = "sk-initial-api-key-value"
|
||||
//nolint:gosec // test credentials
|
||||
updatedAPIKey = "sk-updated-api-key-value"
|
||||
)
|
||||
|
||||
insertProviderAndKey := func(
|
||||
t *testing.T,
|
||||
crypt *dbCrypt,
|
||||
ciphers []Cipher,
|
||||
) (database.ChatProvider, database.UserChatProviderKey) {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
provider := dbgen.ChatProvider(t, crypt, database.ChatProvider{
|
||||
AllowUserApiKey: true,
|
||||
}, func(params *database.InsertChatProviderParams) {
|
||||
params.APIKey = ""
|
||||
})
|
||||
|
||||
key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: user.ID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: initialAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialAPIKey, key.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String)
|
||||
return provider, key
|
||||
}
|
||||
|
||||
getUserChatProviderKey := func(t *testing.T, store interface {
|
||||
GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error)
|
||||
}, userID uuid.UUID, providerID uuid.UUID,
|
||||
) database.UserChatProviderKey {
|
||||
t.Helper()
|
||||
keys, err := store.GetUserChatProviderKeys(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, providerID, keys[0].ChatProviderID)
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
|
||||
require.Equal(t, key.ID, got.ID)
|
||||
require.Equal(t, initialAPIKey, got.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
|
||||
|
||||
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
|
||||
require.NotEqual(t, initialAPIKey, rawKey.APIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey)
|
||||
})
|
||||
|
||||
t.Run("GetUserChatProviderKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, crypt, ciphers := setup(t)
|
||||
_, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, key.ID, keys[0].ID)
|
||||
require.Equal(t, initialAPIKey, keys[0].APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String)
|
||||
})
|
||||
|
||||
t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: key.UserID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: updatedAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key.ID, updated.ID)
|
||||
require.Equal(t, key.CreatedAt, updated.CreatedAt)
|
||||
require.False(t, updated.UpdatedAt.Before(key.UpdatedAt))
|
||||
require.Equal(t, updatedAPIKey, updated.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String)
|
||||
|
||||
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
|
||||
require.Equal(t, updatedAPIKey, got.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
|
||||
|
||||
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, updatedAPIKey, keys[0].APIKey)
|
||||
|
||||
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
|
||||
require.NotEqual(t, updatedAPIKey, rawKey.APIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -60,10 +60,15 @@ else
|
||||
echo "Base ref $base not found locally, fetching $ref..."
|
||||
git fetch origin "$ref" --depth=1 2>/dev/null || true
|
||||
if ! git rev-parse --verify "$base" >/dev/null 2>&1; then
|
||||
if git rev-parse --verify origin/main >/dev/null 2>&1; then
|
||||
echo "WARNING: could not fetch base ref $base, falling back to origin/main merge base."
|
||||
base=$(git merge-base HEAD origin/main 2>/dev/null || echo "origin/main")
|
||||
else
|
||||
echo "ERROR: could not fetch base ref $base."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
found=0
|
||||
if ! diff_output=$(git diff "$base" -U0 -- . "${exclude_pathspecs[@]}" 2>&1); then
|
||||
|
||||
@@ -284,11 +284,6 @@ describe("api.ts", () => {
|
||||
providers: [],
|
||||
},
|
||||
],
|
||||
[
|
||||
"/api/experimental/chats/providers",
|
||||
() => API.experimental.getChatProviderConfigs(),
|
||||
[],
|
||||
],
|
||||
[
|
||||
"/api/experimental/chats/model-configs",
|
||||
() => API.experimental.getChatModelConfigs(),
|
||||
@@ -310,10 +305,6 @@ describe("api.ts", () => {
|
||||
"/api/experimental/chats/models",
|
||||
() => API.experimental.getChatModels(),
|
||||
],
|
||||
[
|
||||
"/api/experimental/chats/providers",
|
||||
() => API.experimental.getChatProviderConfigs(),
|
||||
],
|
||||
[
|
||||
"/api/experimental/chats/model-configs",
|
||||
() => API.experimental.getChatModelConfigs(),
|
||||
|
||||
@@ -407,11 +407,8 @@ export type DeploymentConfig = Readonly<{
|
||||
options: TypesGen.SerpentOption[];
|
||||
}>;
|
||||
|
||||
const chatProviderConfigsPath = "/api/experimental/chats/providers";
|
||||
const aiProviderConfigsPath = "/api/v2/ai/providers";
|
||||
const chatModelConfigsPath = "/api/experimental/chats/model-configs";
|
||||
const userChatProviderConfigsPath =
|
||||
"/api/experimental/chats/user-provider-configs";
|
||||
const userSkillsPath = (user: string) =>
|
||||
`/api/experimental/users/${encodeURIComponent(user)}/skills`;
|
||||
const userSkillPath = (user: string, name: string) =>
|
||||
@@ -3731,42 +3728,6 @@ class ExperimentalApiMethods {
|
||||
);
|
||||
};
|
||||
|
||||
getChatProviderConfigs = async (): Promise<TypesGen.ChatProviderConfig[]> => {
|
||||
const response = await this.axios.get<TypesGen.ChatProviderConfig[]>(
|
||||
chatProviderConfigsPath,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
createChatProviderConfig = async (
|
||||
req: TypesGen.CreateChatProviderConfigRequest,
|
||||
): Promise<TypesGen.ChatProviderConfig> => {
|
||||
const response = await this.axios.post<TypesGen.ChatProviderConfig>(
|
||||
chatProviderConfigsPath,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
updateChatProviderConfig = async (
|
||||
providerConfigId: string,
|
||||
req: TypesGen.UpdateChatProviderConfigRequest,
|
||||
): Promise<TypesGen.ChatProviderConfig> => {
|
||||
const response = await this.axios.patch<TypesGen.ChatProviderConfig>(
|
||||
`${chatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
deleteChatProviderConfig = async (
|
||||
providerConfigId: string,
|
||||
): Promise<void> => {
|
||||
await this.axios.delete(
|
||||
`${chatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
|
||||
);
|
||||
};
|
||||
|
||||
getChatModelConfigs = async (): Promise<TypesGen.ChatModelConfig[]> => {
|
||||
const response =
|
||||
await this.axios.get<TypesGen.ChatModelConfig[]>(chatModelConfigsPath);
|
||||
@@ -3800,34 +3761,6 @@ class ExperimentalApiMethods {
|
||||
);
|
||||
};
|
||||
|
||||
getUserChatProviderConfigs = async (): Promise<
|
||||
TypesGen.UserChatProviderConfig[]
|
||||
> => {
|
||||
const response = await this.axios.get<TypesGen.UserChatProviderConfig[]>(
|
||||
userChatProviderConfigsPath,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
upsertUserChatProviderKey = async (
|
||||
providerConfigId: string,
|
||||
req: TypesGen.CreateUserChatProviderKeyRequest,
|
||||
): Promise<TypesGen.UserChatProviderConfig> => {
|
||||
const response = await this.axios.put<TypesGen.UserChatProviderConfig>(
|
||||
`${userChatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
deleteUserChatProviderKey = async (
|
||||
providerConfigId: string,
|
||||
): Promise<void> => {
|
||||
await this.axios.delete(
|
||||
`${userChatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`,
|
||||
);
|
||||
};
|
||||
|
||||
getMCPServerConfigs = async (): Promise<TypesGen.MCPServerConfig[]> => {
|
||||
const response =
|
||||
await this.axios.get<TypesGen.MCPServerConfig[]>(mcpServerConfigsPath);
|
||||
|
||||
+53
-38
@@ -39,12 +39,36 @@ const createProviderConfig = (
|
||||
updated_at: overrides.updated_at ?? now,
|
||||
});
|
||||
|
||||
const createProviderKey = (providerId: string): TypesGen.AIProviderKey => ({
|
||||
id: `key-${providerId}`,
|
||||
masked: "sk-...test",
|
||||
created_at: now,
|
||||
});
|
||||
|
||||
const toAIProvider = (
|
||||
providerConfig: TypesGen.ChatProviderConfig,
|
||||
): TypesGen.AIProvider => ({
|
||||
id: providerConfig.id,
|
||||
type: providerConfig.provider as TypesGen.AIProviderType,
|
||||
name: providerConfig.provider,
|
||||
display_name: providerConfig.display_name,
|
||||
enabled: providerConfig.enabled,
|
||||
base_url: providerConfig.base_url ?? "",
|
||||
api_keys: providerConfig.has_api_key
|
||||
? [createProviderKey(providerConfig.id)]
|
||||
: [],
|
||||
settings: {},
|
||||
created_at: providerConfig.created_at ?? now,
|
||||
updated_at: providerConfig.updated_at ?? now,
|
||||
});
|
||||
|
||||
const createModelConfig = (
|
||||
overrides: Partial<TypesGen.ChatModelConfig> &
|
||||
Pick<TypesGen.ChatModelConfig, "id" | "provider" | "model">,
|
||||
): TypesGen.ChatModelConfig => ({
|
||||
id: overrides.id,
|
||||
provider: overrides.provider,
|
||||
ai_provider_id: overrides.ai_provider_id,
|
||||
model: overrides.model,
|
||||
display_name: overrides.display_name ?? overrides.model,
|
||||
enabled: overrides.enabled ?? true,
|
||||
@@ -68,11 +92,9 @@ const setupChatSpies = (state: {
|
||||
modelConfigs: TypesGen.ChatModelConfig[];
|
||||
modelCatalog: TypesGen.ChatModelsResponse;
|
||||
}) => {
|
||||
spyOn(API.experimental, "getChatProviderConfigs").mockImplementation(
|
||||
async () => {
|
||||
return state.providerConfigs;
|
||||
},
|
||||
);
|
||||
spyOn(API.experimental, "listAIProviders").mockImplementation(async () => {
|
||||
return state.providerConfigs.map(toAIProvider);
|
||||
});
|
||||
spyOn(API.experimental, "getChatModelConfigs").mockImplementation(
|
||||
async () => {
|
||||
return state.modelConfigs;
|
||||
@@ -82,31 +104,26 @@ const setupChatSpies = (state: {
|
||||
return state.modelCatalog;
|
||||
});
|
||||
|
||||
spyOn(API.experimental, "createChatProviderConfig").mockImplementation(
|
||||
spyOn(API.experimental, "createAIProvider").mockImplementation(
|
||||
async (req) => {
|
||||
const created = createProviderConfig({
|
||||
id: `provider-${Date.now()}`,
|
||||
provider: req.provider ?? "",
|
||||
provider: req.type ?? "openai",
|
||||
display_name: req.display_name ?? "",
|
||||
has_api_key: (req.api_key ?? "").trim().length > 0,
|
||||
central_api_key_enabled: req.central_api_key_enabled ?? true,
|
||||
allow_user_api_key: req.allow_user_api_key ?? false,
|
||||
allow_central_api_key_fallback:
|
||||
req.allow_central_api_key_fallback ?? false,
|
||||
has_api_key:
|
||||
req.api_keys?.some((apiKey) => apiKey.trim().length > 0) ?? false,
|
||||
base_url: req.base_url ?? "",
|
||||
enabled: req.enabled ?? true,
|
||||
source: "database",
|
||||
});
|
||||
state.providerConfigs = [
|
||||
...state.providerConfigs.filter(
|
||||
(p) => !(p.id === nilProviderConfigID && p.provider === req.provider),
|
||||
),
|
||||
...state.providerConfigs.filter((p) => p.id !== created.id),
|
||||
created,
|
||||
];
|
||||
return created;
|
||||
return toAIProvider(created);
|
||||
},
|
||||
);
|
||||
|
||||
spyOn(API.experimental, "updateChatProviderConfig").mockImplementation(
|
||||
spyOn(API.experimental, "updateAIProvider").mockImplementation(
|
||||
async (providerConfigId, req) => {
|
||||
const idx = state.providerConfigs.findIndex(
|
||||
(p) => p.id === providerConfigId,
|
||||
@@ -122,29 +139,30 @@ const setupChatSpies = (state: {
|
||||
? req.display_name
|
||||
: current.display_name,
|
||||
has_api_key:
|
||||
typeof req.api_key === "string"
|
||||
? req.api_key.trim().length > 0
|
||||
: current.has_api_key,
|
||||
central_api_key_enabled:
|
||||
typeof req.central_api_key_enabled === "boolean"
|
||||
? req.central_api_key_enabled
|
||||
: current.central_api_key_enabled,
|
||||
allow_user_api_key:
|
||||
typeof req.allow_user_api_key === "boolean"
|
||||
? req.allow_user_api_key
|
||||
: current.allow_user_api_key,
|
||||
allow_central_api_key_fallback:
|
||||
typeof req.allow_central_api_key_fallback === "boolean"
|
||||
? req.allow_central_api_key_fallback
|
||||
: current.allow_central_api_key_fallback,
|
||||
req.api_keys === undefined
|
||||
? current.has_api_key
|
||||
: req.api_keys.some((apiKey) =>
|
||||
apiKey.api_key !== undefined
|
||||
? apiKey.api_key.trim().length > 0
|
||||
: apiKey.id !== undefined,
|
||||
),
|
||||
base_url:
|
||||
typeof req.base_url === "string" ? req.base_url : current.base_url,
|
||||
enabled:
|
||||
typeof req.enabled === "boolean" ? req.enabled : current.enabled,
|
||||
updated_at: now,
|
||||
};
|
||||
state.providerConfigs = state.providerConfigs.map((p, i) =>
|
||||
i === idx ? updated : p,
|
||||
);
|
||||
return updated;
|
||||
return toAIProvider(updated);
|
||||
},
|
||||
);
|
||||
spyOn(API.experimental, "deleteAIProvider").mockImplementation(
|
||||
async (providerConfigId) => {
|
||||
state.providerConfigs = state.providerConfigs.filter(
|
||||
(p) => p.id !== providerConfigId,
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
@@ -154,6 +172,7 @@ const setupChatSpies = (state: {
|
||||
id: `model-${state.modelConfigs.length + 1}`,
|
||||
provider: req.provider ?? "",
|
||||
model: req.model,
|
||||
ai_provider_id: req.ai_provider_id,
|
||||
display_name: req.display_name || req.model,
|
||||
enabled: req.enabled ?? true,
|
||||
context_limit:
|
||||
@@ -181,10 +200,6 @@ const setupChatSpies = (state: {
|
||||
},
|
||||
);
|
||||
|
||||
// Unused but mock to avoid errors.
|
||||
spyOn(API.experimental, "deleteChatProviderConfig").mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
spyOn(API.experimental, "updateChatModelConfig").mockImplementation(
|
||||
async (modelConfigId, req) => {
|
||||
const idx = state.modelConfigs.findIndex((m) => m.id === modelConfigId);
|
||||
|
||||
Reference in New Issue
Block a user