From ca1f6b19a2d48ca019ac299072ccb74690706cc9 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 22 May 2026 09:50:01 +0200 Subject: [PATCH] feat: remove legacy chat provider tables (#25416) --- coderd/coderdtest/chat.go | 24 +- coderd/database/check_constraint.go | 86 +- coderd/database/dbauthz/dbauthz.go | 121 +- coderd/database/dbauthz/dbauthz_test.go | 108 +- coderd/database/dbgen/dbgen.go | 58 +- coderd/database/dbmetrics/querymetrics.go | 120 +- coderd/database/dbmock/dbmock.go | 222 +--- coderd/database/dump.sql | 61 +- coderd/database/foreign_key_constraint.go | 5 - .../database/legacy_chat_provider_compat.go | 44 + .../000459_provider_key_policy.down.sql | 17 +- ...rop_chat_model_config_provider_fk.down.sql | 59 +- .../000504_ai_providers_backfill.down.sql | 97 +- ...00505_ai_providers_legacy_cleanup.down.sql | 3 + .../000505_ai_providers_legacy_cleanup.up.sql | 140 +++ coderd/database/models.go | 27 - coderd/database/querier.go | 17 +- coderd/database/querier_test.go | 163 +-- coderd/database/queries.sql.go | 565 ++------- coderd/database/queries/ai_provider_keys.sql | 14 + coderd/database/queries/chatmodelconfigs.sql | 31 +- coderd/database/queries/chatproviders.sql | 102 -- .../database/queries/userchatproviderkeys.sql | 20 - coderd/database/unique_constraint.go | 4 - coderd/exp_chats.go | 1053 +++-------------- coderd/exp_chats_test.go | 617 ++-------- ...kspaceagents_chat_context_internal_test.go | 23 +- coderd/x/chatd/advisor_internal_test.go | 4 +- coderd/x/chatd/chatd.go | 114 +- coderd/x/chatd/chatd_internal_test.go | 47 +- coderd/x/chatd/chatd_test.go | 40 +- coderd/x/chatd/configcache.go | 14 +- coderd/x/chatd/configcache_test.go | 59 +- coderd/x/chatd/subagent.go | 4 +- coderd/x/chatd/subagent_internal_test.go | 54 +- coderd/x/chatd/title_override_test.go | 9 +- coderd/x/chatd/turn_summary_internal_test.go | 26 +- coderd/x/gitsync/worker_test.go | 2 +- enterprise/coderd/x/chatd/chatd_test.go | 49 +- enterprise/dbcrypt/cliutil.go | 107 +- enterprise/dbcrypt/dbcrypt.go | 150 +-- enterprise/dbcrypt/dbcrypt_internal_test.go | 119 +- scripts/check_emdash.sh | 9 +- site/src/api/api.test.ts | 9 - site/src/api/api.ts | 67 -- .../ChatModelAdminPanel.stories.tsx | 91 +- 46 files changed, 1270 insertions(+), 3505 deletions(-) create mode 100644 coderd/database/legacy_chat_provider_compat.go create mode 100644 coderd/database/migrations/000505_ai_providers_legacy_cleanup.down.sql create mode 100644 coderd/database/migrations/000505_ai_providers_legacy_cleanup.up.sql delete mode 100644 coderd/database/queries/chatproviders.sql delete mode 100644 coderd/database/queries/userchatproviderkeys.sql diff --git a/coderd/coderdtest/chat.go b/coderd/coderdtest/chat.go index 797886f068..bf460a5ff0 100644 --- a/coderd/coderdtest/chat.go +++ b/coderd/coderdtest/chat.go @@ -47,10 +47,9 @@ func FakeOpenAICompatProviderAPIKeys(t testing.TB) chatprovider.ProviderAPIKeys } // CreateOpenAICompatChatModelConfig creates the default provider and model -// config used by chat runtime tests. Tests that create chats should also set -// Options.ChatProviderAPIKeys, usually via FakeOpenAICompatProviderAPIKeys, so -// background chat work routes to a local provider until coderd closes. baseURL, -// when non-empty, is stored on the provider config. +// config used by chat runtime tests. Tests can pass a baseURL to route chat work +// to a specific local provider. If baseURL is empty, this helper starts a fake +// OpenAI-compatible provider. func CreateOpenAICompatChatModelConfig( t testing.TB, client *codersdk.ExperimentalClient, @@ -58,26 +57,19 @@ func CreateOpenAICompatChatModelConfig( ) codersdk.ChatModelConfig { t.Helper() - ctx := testutil.Context(t, testutil.WaitLong) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: TestChatProviderOpenAICompat, - APIKey: TestChatProviderAPIKey, - BaseURL: baseURL, - }) - require.NoError(t, err) - aiProviderBaseURL := baseURL - if aiProviderBaseURL == "" { - aiProviderBaseURL = "https://api.example.com/v1" + if baseURL == "" { + baseURL = chattest.OpenAI(t) } + + ctx := testutil.Context(t, testutil.WaitLong) provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ Type: codersdk.AIProviderType(TestChatProviderOpenAICompat), Name: "test-" + uuid.NewString(), - BaseURL: aiProviderBaseURL, + BaseURL: baseURL, Enabled: true, APIKeys: []string{TestChatProviderAPIKey}, }) require.NoError(t, err) - contextLimit := int64(4096) isDefault := true modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index c5e109af38..5ddfabff0e 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -6,48 +6,46 @@ type CheckConstraint string // CheckConstraint enums. const ( - CheckAiModelPricesCacheReadPriceCheck CheckConstraint = "ai_model_prices_cache_read_price_check" // ai_model_prices - CheckAiModelPricesCacheWritePriceCheck CheckConstraint = "ai_model_prices_cache_write_price_check" // ai_model_prices - CheckAiModelPricesInputPriceCheck CheckConstraint = "ai_model_prices_input_price_check" // ai_model_prices - CheckAiModelPricesOutputPriceCheck CheckConstraint = "ai_model_prices_output_price_check" // ai_model_prices - CheckAiProvidersNameCheck CheckConstraint = "ai_providers_name_check" // ai_providers - CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys - CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs - CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs - CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers - CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers - CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config - CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config - CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config - CheckChatAclOnlyOnRootChats CheckConstraint = "chat_acl_only_on_root_chats" // chats - CheckChatGroupAclNotNullJsonb CheckConstraint = "chat_group_acl_not_null_jsonb" // chats - CheckChatUserAclNotNullJsonb CheckConstraint = "chat_user_acl_not_null_jsonb" // chats - CheckChatsPinOrderArchivedCheck CheckConstraint = "chats_pin_order_archived_check" // chats - CheckChatsPinOrderParentCheck CheckConstraint = "chats_pin_order_parent_check" // chats - CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users - CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users - CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users - CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users - CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users - CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles - CheckGroupAiBudgetsSpendLimitMicrosCheck CheckConstraint = "group_ai_budgets_spend_limit_micros_check" // group_ai_budgets - CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups - CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs - CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs - CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs - CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs - CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents - CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents - CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds - CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces - CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces - CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks - CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters - CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events - CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys - CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys - CheckUserSkillsContentSize CheckConstraint = "user_skills_content_size" // user_skills - CheckUserSkillsDescriptionSize CheckConstraint = "user_skills_description_size" // user_skills - CheckUserSkillsNameFormat CheckConstraint = "user_skills_name_format" // user_skills - CheckUserSkillsNameSize CheckConstraint = "user_skills_name_size" // user_skills + CheckAiModelPricesCacheReadPriceCheck CheckConstraint = "ai_model_prices_cache_read_price_check" // ai_model_prices + CheckAiModelPricesCacheWritePriceCheck CheckConstraint = "ai_model_prices_cache_write_price_check" // ai_model_prices + CheckAiModelPricesInputPriceCheck CheckConstraint = "ai_model_prices_input_price_check" // ai_model_prices + CheckAiModelPricesOutputPriceCheck CheckConstraint = "ai_model_prices_output_price_check" // ai_model_prices + CheckAiProvidersNameCheck CheckConstraint = "ai_providers_name_check" // ai_providers + CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys + CheckChatModelConfigsAiProviderRequiredWhenActive CheckConstraint = "chat_model_configs_ai_provider_required_when_active" // chat_model_configs + CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs + CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs + CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config + CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config + CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config + CheckChatAclOnlyOnRootChats CheckConstraint = "chat_acl_only_on_root_chats" // chats + CheckChatGroupAclNotNullJsonb CheckConstraint = "chat_group_acl_not_null_jsonb" // chats + CheckChatUserAclNotNullJsonb CheckConstraint = "chat_user_acl_not_null_jsonb" // chats + CheckChatsPinOrderArchivedCheck CheckConstraint = "chats_pin_order_archived_check" // chats + CheckChatsPinOrderParentCheck CheckConstraint = "chats_pin_order_parent_check" // chats + CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users + CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users + CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users + CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users + CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users + CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles + CheckGroupAiBudgetsSpendLimitMicrosCheck CheckConstraint = "group_ai_budgets_spend_limit_micros_check" // group_ai_budgets + CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups + CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs + CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs + CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs + CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs + CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents + CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents + CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds + CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces + CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces + CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks + CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters + CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events + CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys + CheckUserSkillsContentSize CheckConstraint = "user_skills_content_size" // user_skills + CheckUserSkillsDescriptionSize CheckConstraint = "user_skills_description_size" // user_skills + CheckUserSkillsNameFormat CheckConstraint = "user_skills_name_format" // user_skills + CheckUserSkillsNameSize CheckConstraint = "user_skills_name_size" // user_skills ) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a76ced7ceb..e0d5617c1a 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1967,6 +1967,13 @@ func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) e return q.db.DeleteChatModelConfigByID(ctx, id) } +func (q *querier) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAIProvider); err != nil { + return err + } + return q.db.DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID) +} + func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err @@ -1974,13 +1981,6 @@ func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider return q.db.DeleteChatModelConfigsByProvider(ctx, provider) } -func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return err - } - return q.db.DeleteChatProviderByID(ctx, id) -} - func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error { chat, err := q.db.GetChatByID(ctx, arg.ChatID) if err != nil { @@ -2313,17 +2313,6 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat return q.db.DeleteUserChatCompactionThreshold(ctx, arg) } -func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { - u, err := q.db.GetUserByID(ctx, arg.UserID) - if err != nil { - return err - } - if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { - return err - } - return q.db.DeleteUserChatProviderKey(ctx, arg) -} - func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) { obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String()) if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil { @@ -2617,6 +2606,13 @@ func (q *querier) GetAIProviderKeysByProviderID(ctx context.Context, providerID return q.db.GetAIProviderKeysByProviderID(ctx, providerID) } +func (q *querier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { + return nil, err + } + return q.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs) +} + func (q *querier) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { return nil, err @@ -3128,41 +3124,6 @@ func (q *querier) GetChatPlanModeInstructions(ctx context.Context) (string, erro return q.db.GetChatPlanModeInstructions(ctx) } -func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByID(ctx, id) -} - -func (q *querier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByIDForUpdate(ctx, id) -} - -func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByProvider(ctx, provider) -} - -func (q *querier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByProviderForUpdate(ctx, provider) -} - -func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return nil, err - } - return q.db.GetChatProviders(ctx) -} - func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) { _, err := q.GetChatByID(ctx, chatID) if err != nil { @@ -3403,13 +3364,6 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch return q.db.GetEnabledChatModelConfigs(ctx) } -func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return nil, err - } - return q.db.GetEnabledChatProviders(ctx) -} - func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return nil, err @@ -4661,17 +4615,6 @@ func (q *querier) GetUserChatPersonalModelOverride(ctx context.Context, arg data return q.db.GetUserChatPersonalModelOverride(ctx, arg) } -func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - u, err := q.db.GetUserByID(ctx, userID) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil { - return nil, err - } - return q.db.GetUserChatProviderKeys(ctx, userID) -} - func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil { return 0, err @@ -5525,13 +5468,6 @@ func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.Insert return q.db.InsertChatModelConfig(ctx, arg) } -func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.InsertChatProvider(ctx, arg) -} - func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) { chat, err := q.db.GetChatByID(ctx, arg.ChatID) if err != nil { @@ -6749,13 +6685,6 @@ func (q *querier) UpdateChatPlanModeByID(ctx context.Context, arg database.Updat return q.db.UpdateChatPlanModeByID(ctx, arg) } -func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.UpdateChatProvider(ctx, arg) -} - func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) { // UpdateChatStatus is used by the chat processor to change chat status. // It should be called with system context. @@ -7429,17 +7358,6 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U return q.db.UpdateUserChatCustomPrompt(ctx, arg) } -func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - u, err := q.db.GetUserByID(ctx, arg.UserID) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { - return database.UserChatProviderKey{}, err - } - return q.db.UpdateUserChatProviderKey(ctx, arg) -} - func (q *querier) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) { user, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { @@ -8371,17 +8289,6 @@ func (q *querier) UpsertUserChatPersonalModelOverride(ctx context.Context, arg d return q.db.UpsertUserChatPersonalModelOverride(ctx, arg) } -func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - u, err := q.db.GetUserByID(ctx, arg.UserID) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { - return database.UserChatProviderKey{}, err - } - return q.db.UpsertUserChatProviderKey(ctx, arg) -} - func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index d6da7333dd..354453da40 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -557,10 +557,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().DeleteChatModelConfigsByProvider(gomock.Any(), providerName).Return(nil).AnyTimes() check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { - id := uuid.New() - dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes() - check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + s.Run("DeleteChatModelConfigsByAIProviderID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + providerID := uuid.New() + dbm.EXPECT().DeleteChatModelConfigsByAIProviderID(gomock.Any(), providerID).Return(nil).AnyTimes() + check.Args(providerID).Asserts(rbac.ResourceAIProvider, policy.ActionDelete) })) s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) @@ -972,34 +972,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB}) })) - s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - provider := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() - check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider) - })) - s.Run("GetChatProviderByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - provider := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetChatProviderByIDForUpdate(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() - check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) - s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerName := "test-provider" - provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName}) - dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes() - check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider) - })) - s.Run("GetChatProviderByProviderForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerName := "test-provider" - provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName}) - dbm.EXPECT().GetChatProviderByProviderForUpdate(gomock.Any(), providerName).Return(provider, nil).AnyTimes() - check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) - s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerA := testutil.Fake(s.T(), faker, database.ChatProvider{}) - providerB := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes() - check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB}) - })) + s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { params := database.GetChatsParams{} dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes() @@ -1112,12 +1085,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB}) })) - s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerA := testutil.Fake(s.T(), faker, database.ChatProvider{}) - providerB := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes() - check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB}) - })) + s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { threshold := dbtime.Now() chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})} @@ -1166,17 +1134,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config) })) - s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - arg := database.InsertChatProviderParams{ - Provider: "test-provider", - DisplayName: "Test Provider", - APIKey: "test-api-key", - Enabled: true, - } - provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled}) - dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) + s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{}) @@ -1303,17 +1261,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config) })) - s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - provider := testutil.Fake(s.T(), faker, database.ChatProvider{}) - arg := database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: "Updated Provider", - APIKey: "updated-api-key", - Enabled: true, - } - dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) + s.Run("UpdateChatPinOrder", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatPinOrderParams{ @@ -2933,36 +2881,7 @@ func (s *MethodTestSuite) TestUser() { dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt") })) - s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID}) - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes() - check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key}) - })) - s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()} - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes() - check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns() - })) - s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"} - key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey}) - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() - check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) - })) - s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"} - key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey}) - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() - check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) - })) + s.Run("GetUserAIProviderKeyByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { u := testutil.Fake(s.T(), faker, database.User{}) arg := database.GetUserAIProviderKeyByProviderIDParams{UserID: u.ID, AIProviderID: uuid.New()} @@ -6582,6 +6501,15 @@ func (s *MethodTestSuite) TestAIBridge() { dbm.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), provider.ID).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes() check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB}) })) + s.Run("GetAIProviderKeysByProviderIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + providerA := testutil.Fake(s.T(), faker, database.AIProvider{}) + providerB := testutil.Fake(s.T(), faker, database.AIProvider{}) + providerIDs := []uuid.UUID{providerA.ID, providerB.ID} + keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerA.ID}) + keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerB.ID}) + dbm.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), providerIDs).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes() + check.Args(providerIDs).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB}) + })) s.Run("GetAIProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{}) keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{}) diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 1dcf80f2c7..200a2ed6dd 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -149,8 +149,29 @@ const ( func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelConfig, munge ...func(*database.InsertChatModelConfigParams)) database.ChatModelConfig { t.Helper() + providerName := takeFirst(seed.Provider, "openai") + aiProviderID := seed.AIProviderID + if !aiProviderID.Valid { + providers, err := db.GetAIProviders(genCtx, database.GetAIProvidersParams{IncludeDisabled: true}) + require.NoError(t, err, "get ai providers") + var provider database.AIProvider + for _, candidate := range providers { + if candidate.Type != database.AIProviderType(providerName) { + continue + } + if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) { + provider = candidate + } + } + if provider.ID == uuid.Nil { + provider = AIProvider(t, db, database.AIProvider{ + Type: database.AIProviderType(providerName), + }) + } + aiProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true} + } params := database.InsertChatModelConfigParams{ - Provider: takeFirst(seed.Provider, "openai"), + Provider: providerName, Model: takeFirst(seed.Model, "gpt-4o-mini"), DisplayName: takeFirst(seed.DisplayName, "Test Model"), CreatedBy: seed.CreatedBy, @@ -160,7 +181,7 @@ func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelCon ContextLimit: takeFirst(seed.ContextLimit, defaultChatModelContextLimit), CompressionThreshold: takeFirst(seed.CompressionThreshold, defaultChatModelCompressionThreshold), Options: takeFirstSlice(seed.Options, json.RawMessage(`{}`)), - AIProviderID: seed.AIProviderID, + AIProviderID: aiProviderID, } for _, fn := range munge { fn(¶ms) @@ -263,9 +284,36 @@ func ChatProvider(t testing.TB, db database.Store, seed database.ChatProvider, m for _, fn := range munge { fn(¶ms) } - provider, err := db.InsertChatProvider(genCtx, params) - require.NoError(t, err, "insert chat provider") - return provider + provider := AIProvider(t, db, database.AIProvider{ + Type: database.AIProviderType(params.Provider), + Name: "test-" + uuid.NewString(), + DisplayName: sql.NullString{String: params.DisplayName, Valid: params.DisplayName != ""}, + BaseUrl: params.BaseUrl, + }, func(p *database.InsertAIProviderParams) { + p.Enabled = params.Enabled + }) + if params.APIKey != "" { + AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: params.APIKey, + ApiKeyKeyID: params.ApiKeyKeyID, + }) + } + return database.ChatProvider{ + ID: provider.ID, + Provider: params.Provider, + DisplayName: params.DisplayName, + APIKey: params.APIKey, + BaseUrl: params.BaseUrl, + ApiKeyKeyID: params.ApiKeyKeyID, + CreatedBy: params.CreatedBy, + Enabled: params.Enabled, + CentralApiKeyEnabled: params.CentralApiKeyEnabled, + AllowUserApiKey: params.AllowUserApiKey, + AllowCentralApiKeyFallback: params.AllowCentralApiKeyFallback, + CreatedAt: provider.CreatedAt, + UpdatedAt: provider.UpdatedAt, + } } func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerConfig) database.MCPServerConfig { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 07960a442a..a53d7962c8 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -465,6 +465,14 @@ func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uui return r0 } +func (m queryMetricsStore) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + start := time.Now() + r0 := m.s.DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID) + m.queryLatencies.WithLabelValues("DeleteChatModelConfigsByAIProviderID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatModelConfigsByAIProviderID").Inc() + return r0 +} + func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error { start := time.Now() r0 := m.s.DeleteChatModelConfigsByProvider(ctx, provider) @@ -473,14 +481,6 @@ func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context, return r0 } -func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - start := time.Now() - r0 := m.s.DeleteChatProviderByID(ctx, id) - m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc() - return r0 -} - func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error { start := time.Now() r0 := m.s.DeleteChatQueuedMessage(ctx, arg) @@ -809,14 +809,6 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context return r0 } -func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { - start := time.Now() - r0 := m.s.DeleteUserChatProviderKey(ctx, arg) - m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc() - return r0 -} - func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) { start := time.Now() r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg) @@ -1089,6 +1081,14 @@ func (m queryMetricsStore) GetAIProviderKeysByProviderID(ctx context.Context, pr return r0, r1 } +func (m queryMetricsStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) { + start := time.Now() + r0, r1 := m.s.GetAIProviderKeysByProviderIDs(ctx, providerIds) + m.queryLatencies.WithLabelValues("GetAIProviderKeysByProviderIDs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderKeysByProviderIDs").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) { start := time.Now() r0, r1 := m.s.GetAIProviders(ctx, arg) @@ -1537,46 +1537,6 @@ func (m queryMetricsStore) GetChatPlanModeInstructions(ctx context.Context) (str return r0, r1 } -func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByID(ctx, id) - m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByIDForUpdate(ctx, id) - m.queryLatencies.WithLabelValues("GetChatProviderByIDForUpdate").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByIDForUpdate").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByProvider(ctx, provider) - m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByProviderForUpdate(ctx, provider) - m.queryLatencies.WithLabelValues("GetChatProviderByProviderForUpdate").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProviderForUpdate").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviders(ctx) - m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) { start := time.Now() r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID) @@ -1833,14 +1793,6 @@ func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]da return r0, r1 } -func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetEnabledChatProviders(ctx) - m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { start := time.Now() r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx) @@ -3033,14 +2985,6 @@ func (m queryMetricsStore) GetUserChatPersonalModelOverride(ctx context.Context, return r0, r1 } -func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - start := time.Now() - r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID) - m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { start := time.Now() r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg) @@ -3825,14 +3769,6 @@ func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg databa return r0, r1 } -func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.InsertChatProvider(ctx, arg) - m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc() - return r0, r1 -} - func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) { start := time.Now() r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg) @@ -4881,14 +4817,6 @@ func (m queryMetricsStore) UpdateChatPlanModeByID(ctx context.Context, arg datab return r0, r1 } -func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.UpdateChatProvider(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatStatus(ctx, arg) @@ -5321,14 +5249,6 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d return r0, r1 } -func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - start := time.Now() - r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) { start := time.Now() r0, r1 := m.s.UpdateUserCodeDiffDisplayMode(ctx, arg) @@ -6105,14 +6025,6 @@ func (m queryMetricsStore) UpsertUserChatPersonalModelOverride(ctx context.Conte return r0 } -func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - start := time.Now() - r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg) - m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { start := time.Now() r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index a62f15b39b..f0dde67907 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -760,6 +760,20 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigByID(ctx, id any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigByID), ctx, id) } +// DeleteChatModelConfigsByAIProviderID mocks base method. +func (m *MockStore) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteChatModelConfigsByAIProviderID", ctx, aiProviderID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteChatModelConfigsByAIProviderID indicates an expected call of DeleteChatModelConfigsByAIProviderID. +func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByAIProviderID(ctx, aiProviderID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByAIProviderID", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByAIProviderID), ctx, aiProviderID) +} + // DeleteChatModelConfigsByProvider mocks base method. func (m *MockStore) DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error { m.ctrl.T.Helper() @@ -774,20 +788,6 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByProvider(ctx, provider return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByProvider", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByProvider), ctx, provider) } -// DeleteChatProviderByID mocks base method. -func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID. -func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id) -} - // DeleteChatQueuedMessage mocks base method. func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error { m.ctrl.T.Helper() @@ -1376,20 +1376,6 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg) } -// DeleteUserChatProviderKey mocks base method. -func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey. -func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg) -} - // DeleteUserSecretByUserIDAndName mocks base method. func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) { m.ctrl.T.Helper() @@ -1889,6 +1875,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderID(ctx, providerID a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderID", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderID), ctx, providerID) } +// GetAIProviderKeysByProviderIDs mocks base method. +func (m *MockStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAIProviderKeysByProviderIDs", ctx, providerIds) + ret0, _ := ret[0].([]database.AIProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAIProviderKeysByProviderIDs indicates an expected call of GetAIProviderKeysByProviderIDs. +func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderIDs(ctx, providerIds any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderIDs", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderIDs), ctx, providerIds) +} + // GetAIProviders mocks base method. func (m *MockStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) { m.ctrl.T.Helper() @@ -2849,81 +2850,6 @@ func (mr *MockStoreMockRecorder) GetChatPlanModeInstructions(ctx any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatPlanModeInstructions", reflect.TypeOf((*MockStore)(nil).GetChatPlanModeInstructions), ctx) } -// GetChatProviderByID mocks base method. -func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByID indicates an expected call of GetChatProviderByID. -func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id) -} - -// GetChatProviderByIDForUpdate mocks base method. -func (m *MockStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByIDForUpdate", ctx, id) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByIDForUpdate indicates an expected call of GetChatProviderByIDForUpdate. -func (mr *MockStoreMockRecorder) GetChatProviderByIDForUpdate(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByIDForUpdate), ctx, id) -} - -// GetChatProviderByProvider mocks base method. -func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider. -func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider) -} - -// GetChatProviderByProviderForUpdate mocks base method. -func (m *MockStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByProviderForUpdate", ctx, provider) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByProviderForUpdate indicates an expected call of GetChatProviderByProviderForUpdate. -func (mr *MockStoreMockRecorder) GetChatProviderByProviderForUpdate(ctx, provider any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProviderForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProviderForUpdate), ctx, provider) -} - -// GetChatProviders mocks base method. -func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviders", ctx) - ret0, _ := ret[0].([]database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviders indicates an expected call of GetChatProviders. -func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx) -} - // GetChatQueuedMessages mocks base method. func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) { m.ctrl.T.Helper() @@ -3404,21 +3330,6 @@ func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx) } -// GetEnabledChatProviders mocks base method. -func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx) - ret0, _ := ret[0].([]database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders. -func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx) -} - // GetEnabledMCPServerConfigs mocks base method. func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { m.ctrl.T.Helper() @@ -5684,21 +5595,6 @@ func (mr *MockStoreMockRecorder) GetUserChatPersonalModelOverride(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).GetUserChatPersonalModelOverride), ctx, arg) } -// GetUserChatProviderKeys mocks base method. -func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID) - ret0, _ := ret[0].([]database.UserChatProviderKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys. -func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID) -} - // GetUserChatSpendInPeriod mocks base method. func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { m.ctrl.T.Helper() @@ -7183,21 +7079,6 @@ func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg) } -// InsertChatProvider mocks base method. -func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// InsertChatProvider indicates an expected call of InsertChatProvider. -func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg) -} - // InsertChatQueuedMessage mocks base method. func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) { m.ctrl.T.Helper() @@ -9234,21 +9115,6 @@ func (mr *MockStoreMockRecorder) UpdateChatPlanModeByID(ctx, arg any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPlanModeByID", reflect.TypeOf((*MockStore)(nil).UpdateChatPlanModeByID), ctx, arg) } -// UpdateChatProvider mocks base method. -func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateChatProvider indicates an expected call of UpdateChatProvider. -func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg) -} - // UpdateChatStatus mocks base method. func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) { m.ctrl.T.Helper() @@ -10035,21 +9901,6 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg) } -// UpdateUserChatProviderKey mocks base method. -func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg) - ret0, _ := ret[0].(database.UserChatProviderKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey. -func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg) -} - // UpdateUserCodeDiffDisplayMode mocks base method. func (m *MockStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) { m.ctrl.T.Helper() @@ -11446,21 +11297,6 @@ func (mr *MockStoreMockRecorder) UpsertUserChatPersonalModelOverride(ctx, arg an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertUserChatPersonalModelOverride), ctx, arg) } -// UpsertUserChatProviderKey mocks base method. -func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg) - ret0, _ := ret[0].(database.UserChatProviderKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey. -func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg) -} - // UpsertWebpushVAPIDKeys mocks base method. func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index e502f69aed..2c0292ab60 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1531,30 +1531,11 @@ CREATE TABLE chat_model_configs ( compression_threshold integer NOT NULL, options jsonb DEFAULT '{}'::jsonb NOT NULL, ai_provider_id uuid, + CONSTRAINT chat_model_configs_ai_provider_required_when_active CHECK (((deleted = true) OR (ai_provider_id IS NOT NULL))), CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))), CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0)) ); -CREATE TABLE chat_providers ( - id uuid DEFAULT gen_random_uuid() NOT NULL, - provider text NOT NULL, - display_name text DEFAULT ''::text NOT NULL, - api_key text DEFAULT ''::text NOT NULL, - api_key_key_id text, - created_by uuid, - enabled boolean DEFAULT true NOT NULL, - created_at timestamp with time zone DEFAULT now() NOT NULL, - updated_at timestamp with time zone DEFAULT now() NOT NULL, - base_url text DEFAULT ''::text NOT NULL, - central_api_key_enabled boolean DEFAULT true NOT NULL, - allow_user_api_key boolean DEFAULT false NOT NULL, - allow_central_api_key_fallback boolean DEFAULT false NOT NULL, - CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))), - CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key)))) -); - -COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted'; - CREATE TABLE chat_queued_messages ( id bigint NOT NULL, chat_id uuid NOT NULL, @@ -3040,17 +3021,6 @@ COMMENT ON COLUMN user_ai_provider_keys.api_key IS 'User-owned API key used to a COMMENT ON COLUMN user_ai_provider_keys.api_key_key_id IS 'The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted.'; -CREATE TABLE user_chat_provider_keys ( - id uuid DEFAULT gen_random_uuid() NOT NULL, - user_id uuid NOT NULL, - chat_provider_id uuid NOT NULL, - api_key text NOT NULL, - api_key_key_id text, - created_at timestamp with time zone DEFAULT now() NOT NULL, - updated_at timestamp with time zone DEFAULT now() NOT NULL, - CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text)) -); - CREATE TABLE user_configs ( user_id uuid NOT NULL, key character varying(256) NOT NULL, @@ -3667,12 +3637,6 @@ ALTER TABLE ONLY chat_messages ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id); - -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider); - ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id); @@ -3889,12 +3853,6 @@ ALTER TABLE ONLY user_ai_provider_keys ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); - -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); - ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); @@ -4112,8 +4070,6 @@ CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING b CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false)); -CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled); - CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id); CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL); @@ -4444,12 +4400,6 @@ ALTER TABLE ONLY chat_model_configs ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); - ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; @@ -4687,15 +4637,6 @@ ALTER TABLE ONLY user_ai_provider_keys ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; - -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; - ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 27eadbe88f..cc2c1ac485 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -24,8 +24,6 @@ const ( ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id); ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); - ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; @@ -105,9 +103,6 @@ const ( ForeignKeyUserAiProviderKeysAiProviderID ForeignKeyConstraint = "user_ai_provider_keys_ai_provider_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE; ForeignKeyUserAiProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_ai_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyUserAiProviderKeysUserID ForeignKeyConstraint = "user_ai_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; - ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; - ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/legacy_chat_provider_compat.go b/coderd/database/legacy_chat_provider_compat.go new file mode 100644 index 0000000000..7749937987 --- /dev/null +++ b/coderd/database/legacy_chat_provider_compat.go @@ -0,0 +1,44 @@ +package database + +import ( + "database/sql" + "time" + + "github.com/google/uuid" +) + +// ChatProvider is the fixture shape accepted by dbgen.ChatProvider. +// +//nolint:revive +type ChatProvider struct { + ID uuid.UUID + Provider string + DisplayName string + APIKey string + BaseUrl string + ApiKeyKeyID sql.NullString + CreatedAt time.Time + UpdatedAt time.Time + CreatedBy uuid.NullUUID + Enabled bool + CentralApiKeyEnabled bool + AllowUserApiKey bool + AllowCentralApiKeyFallback bool +} + +// InsertChatProviderParams is the callback parameter shape accepted by +// dbgen.ChatProvider. +// +//nolint:revive +type InsertChatProviderParams struct { + Provider string + DisplayName string + APIKey string + BaseUrl string + ApiKeyKeyID sql.NullString + CreatedBy uuid.NullUUID + Enabled bool + CentralApiKeyEnabled bool + AllowUserApiKey bool + AllowCentralApiKeyFallback bool +} diff --git a/coderd/database/migrations/000459_provider_key_policy.down.sql b/coderd/database/migrations/000459_provider_key_policy.down.sql index b7a5bc2a55..7e5e9c2047 100644 --- a/coderd/database/migrations/000459_provider_key_policy.down.sql +++ b/coderd/database/migrations/000459_provider_key_policy.down.sql @@ -1,8 +1,15 @@ DROP TABLE IF EXISTS user_chat_provider_keys; -ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy; +DO $$ +BEGIN + IF to_regclass('chat_providers') IS NULL THEN + RETURN; + END IF; -ALTER TABLE chat_providers - DROP COLUMN IF EXISTS central_api_key_enabled, - DROP COLUMN IF EXISTS allow_user_api_key, - DROP COLUMN IF EXISTS allow_central_api_key_fallback; + ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy; + + ALTER TABLE chat_providers + DROP COLUMN IF EXISTS central_api_key_enabled, + DROP COLUMN IF EXISTS allow_user_api_key, + DROP COLUMN IF EXISTS allow_central_api_key_fallback; +END $$; diff --git a/coderd/database/migrations/000474_drop_chat_model_config_provider_fk.down.sql b/coderd/database/migrations/000474_drop_chat_model_config_provider_fk.down.sql index 3b5ce550ce..98997ffe4c 100644 --- a/coderd/database/migrations/000474_drop_chat_model_config_provider_fk.down.sql +++ b/coderd/database/migrations/000474_drop_chat_model_config_provider_fk.down.sql @@ -1,27 +1,34 @@ --- Restore placeholder provider rows before re-adding the provider FK. --- --- The companion up migration dropped chat_model_configs.provider's foreign --- key, so historical model-config rows can outlive a deleted provider row. --- These backfilled providers are deliberately disabled stubs with empty --- credential fields, which lets rollback restore referential integrity --- without re-enabling a provider. This insert depends on the current --- provider whitelist still admitting every historical --- chat_model_configs.provider value, and on the omitted columns keeping --- compatible defaults. Operators restoring a real provider should update the --- stub row, including credential-policy flags such as --- central_api_key_enabled, before enabling it, rather than insert a second --- row with the same provider name. -INSERT INTO chat_providers (provider, enabled) -SELECT DISTINCT - cmc.provider, - FALSE -FROM - chat_model_configs cmc -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider -WHERE - cp.provider IS NULL; +DO $$ +BEGIN + IF to_regclass('chat_providers') IS NULL THEN + RETURN; + END IF; -ALTER TABLE chat_model_configs - ADD CONSTRAINT chat_model_configs_provider_fkey - FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE; + -- Restore placeholder provider rows before re-adding the provider FK. + -- + -- The companion up migration dropped chat_model_configs.provider's foreign + -- key, so historical model-config rows can outlive a deleted provider row. + -- These backfilled providers are deliberately disabled stubs with empty + -- credential fields, which lets rollback restore referential integrity + -- without re-enabling a provider. This insert depends on the current + -- provider whitelist still admitting every historical + -- chat_model_configs.provider value, and on the omitted columns keeping + -- compatible defaults. Operators restoring a real provider should update the + -- stub row, including credential-policy flags such as + -- central_api_key_enabled, before enabling it, rather than insert a second + -- row with the same provider name. + INSERT INTO chat_providers (provider, enabled) + SELECT DISTINCT + cmc.provider, + FALSE + FROM + chat_model_configs cmc + LEFT JOIN + chat_providers cp ON cp.provider = cmc.provider + WHERE + cp.provider IS NULL; + + ALTER TABLE chat_model_configs + ADD CONSTRAINT chat_model_configs_provider_fkey + FOREIGN KEY (provider) REFERENCES chat_providers(provider) ON DELETE CASCADE; +END $$; diff --git a/coderd/database/migrations/000504_ai_providers_backfill.down.sql b/coderd/database/migrations/000504_ai_providers_backfill.down.sql index eb99d53906..af85461509 100644 --- a/coderd/database/migrations/000504_ai_providers_backfill.down.sql +++ b/coderd/database/migrations/000504_ai_providers_backfill.down.sql @@ -1,48 +1,55 @@ -WITH migrated_provider_ids AS ( - SELECT id - FROM chat_providers - UNION - SELECT id - FROM ai_providers - WHERE name LIKE 'agents-%' - AND deleted = TRUE -) -UPDATE chat_model_configs -SET ai_provider_id = NULL -WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids); +DO $$ +BEGIN + IF to_regclass('chat_providers') IS NULL THEN + RETURN; + END IF; -WITH migrated_provider_ids AS ( - SELECT id - FROM chat_providers - UNION - SELECT id - FROM ai_providers - WHERE name LIKE 'agents-%' - AND deleted = TRUE -) -DELETE FROM user_ai_provider_keys -WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids); + WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE + ) + UPDATE chat_model_configs + SET ai_provider_id = NULL + WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids); -WITH migrated_provider_ids AS ( - SELECT id - FROM chat_providers - UNION - SELECT id - FROM ai_providers - WHERE name LIKE 'agents-%' - AND deleted = TRUE -) -DELETE FROM ai_provider_keys -WHERE provider_id IN (SELECT id FROM migrated_provider_ids); + WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE + ) + DELETE FROM user_ai_provider_keys + WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids); -WITH migrated_provider_ids AS ( - SELECT id - FROM chat_providers - UNION - SELECT id - FROM ai_providers - WHERE name LIKE 'agents-%' - AND deleted = TRUE -) -DELETE FROM ai_providers -WHERE id IN (SELECT id FROM migrated_provider_ids); + WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE + ) + DELETE FROM ai_provider_keys + WHERE provider_id IN (SELECT id FROM migrated_provider_ids); + + WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE + ) + DELETE FROM ai_providers + WHERE id IN (SELECT id FROM migrated_provider_ids); +END $$; diff --git a/coderd/database/migrations/000505_ai_providers_legacy_cleanup.down.sql b/coderd/database/migrations/000505_ai_providers_legacy_cleanup.down.sql new file mode 100644 index 0000000000..793981b9e9 --- /dev/null +++ b/coderd/database/migrations/000505_ai_providers_legacy_cleanup.down.sql @@ -0,0 +1,3 @@ +-- no-op. Legacy chat provider tables are intentionally not recreated from AI +-- provider definitions. Rolling back past this migration is not reversible at +-- the schema level. diff --git a/coderd/database/migrations/000505_ai_providers_legacy_cleanup.up.sql b/coderd/database/migrations/000505_ai_providers_legacy_cleanup.up.sql new file mode 100644 index 0000000000..87591c6ee6 --- /dev/null +++ b/coderd/database/migrations/000505_ai_providers_legacy_cleanup.up.sql @@ -0,0 +1,140 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM chat_providers cp + JOIN ai_providers ap ON ap.name = 'agents-' || cp.provider + WHERE ap.deleted = FALSE + AND ap.id != cp.id + ) THEN + RAISE EXCEPTION 'cannot finalize chat provider migration because a live agents-* AI provider name already exists'; + END IF; +END $$; + +INSERT INTO ai_providers ( + id, + type, + name, + display_name, + enabled, + base_url, + created_at, + updated_at +) +SELECT + cp.id, + cp.provider::ai_provider_type, + 'agents-' || cp.provider, + NULLIF(cp.display_name, ''), + cp.enabled, + cp.base_url, + cp.created_at, + cp.updated_at +FROM chat_providers cp +WHERE NOT EXISTS ( + SELECT 1 + FROM ai_providers ap + WHERE ap.id = cp.id +); + +UPDATE ai_providers ap +SET + type = cp.provider::ai_provider_type, + name = 'agents-' || cp.provider, + display_name = NULLIF(cp.display_name, ''), + enabled = cp.enabled, + deleted = FALSE, + base_url = cp.base_url, + updated_at = GREATEST(cp.updated_at, ap.updated_at) +FROM chat_providers cp +WHERE ap.id = cp.id + AND (cp.updated_at > ap.updated_at OR ap.deleted); + +DELETE FROM ai_provider_keys apk +USING chat_providers cp +WHERE cp.id = apk.provider_id + AND cp.api_key = '' + AND cp.updated_at > apk.updated_at; + +WITH runtime_provider_keys AS ( + SELECT DISTINCT ON (apk.provider_id) + apk.id, + apk.provider_id + FROM ai_provider_keys apk + JOIN chat_providers cp ON cp.id = apk.provider_id + WHERE cp.api_key != '' + ORDER BY + apk.provider_id ASC, + apk.created_at ASC, + apk.id ASC +) +UPDATE ai_provider_keys apk +SET + api_key = cp.api_key, + api_key_key_id = cp.api_key_key_id, + updated_at = cp.updated_at +FROM runtime_provider_keys rpk +JOIN chat_providers cp ON cp.id = rpk.provider_id +WHERE apk.id = rpk.id + AND cp.updated_at > apk.updated_at; + +INSERT INTO ai_provider_keys ( + id, + provider_id, + api_key, + api_key_key_id, + created_at, + updated_at +) +SELECT + gen_random_uuid(), + cp.id, + cp.api_key, + cp.api_key_key_id, + cp.updated_at, + cp.updated_at +FROM chat_providers cp +WHERE cp.api_key != '' + AND NOT EXISTS ( + SELECT 1 + FROM ai_provider_keys apk + WHERE apk.provider_id = cp.id + ); + +INSERT INTO user_ai_provider_keys ( + id, + user_id, + ai_provider_id, + api_key, + api_key_key_id, + created_at, + updated_at +) +SELECT + ucpk.id, + ucpk.user_id, + ucpk.chat_provider_id, + ucpk.api_key, + ucpk.api_key_key_id, + ucpk.created_at, + ucpk.updated_at +FROM user_chat_provider_keys ucpk +ON CONFLICT (user_id, ai_provider_id) DO UPDATE +SET + api_key = EXCLUDED.api_key, + api_key_key_id = EXCLUDED.api_key_key_id, + updated_at = EXCLUDED.updated_at +WHERE user_ai_provider_keys.updated_at < EXCLUDED.updated_at; + +UPDATE chat_model_configs cmc +SET ai_provider_id = cp.id +FROM chat_providers cp +WHERE cmc.provider = cp.provider + AND cmc.ai_provider_id IS NULL; + +ALTER TABLE chat_model_configs + ADD CONSTRAINT chat_model_configs_ai_provider_required_when_active + CHECK (deleted = TRUE OR ai_provider_id IS NOT NULL); + +DROP TABLE IF EXISTS user_chat_provider_keys; +DROP TABLE IF EXISTS chat_providers; diff --git a/coderd/database/models.go b/coderd/database/models.go index 169080f8fc..080e9ae027 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4682,23 +4682,6 @@ type ChatModelConfig struct { AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"` } -type ChatProvider struct { - ID uuid.UUID `db:"id" json:"id"` - Provider string `db:"provider" json:"provider"` - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - // The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - Enabled bool `db:"enabled" json:"enabled"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - BaseUrl string `db:"base_url" json:"base_url"` - CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` - AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` - AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` -} - type ChatQueuedMessage struct { ID int64 `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` @@ -5706,16 +5689,6 @@ type UserAiProviderKey struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } -type UserChatProviderKey struct { - ID uuid.UUID `db:"id" json:"id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` - APIKey string `db:"api_key" json:"api_key"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` -} - type UserConfig struct { UserID uuid.UUID `db:"user_id" json:"user_id"` Key string `db:"key" json:"key"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6c89ffe352..92566c8074 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -122,8 +122,8 @@ type sqlcQuerier interface { // archive-cleanup retry). DeleteChatDebugDataByChatID(ctx context.Context, arg DeleteChatDebugDataByChatIDParams) (int64, error) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error + DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error - DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error @@ -196,7 +196,6 @@ type sqlcQuerier interface { DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error - DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (UserSecret, error) DeleteUserSkillByUserIDAndName(ctx context.Context, arg DeleteUserSkillByUserIDAndNameParams) (UserSkill, error) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error @@ -272,6 +271,9 @@ type sqlcQuerier interface { // key per provider; multiple keys are stored to support future // failover and rotation flows. GetAIProviderKeysByProviderID(ctx context.Context, providerID uuid.UUID) ([]AIProviderKey, error) + // Returns all keys for the requested providers, ordered by provider then created_at ASC + // so callers can select the oldest non-empty key per provider without issuing N queries. + GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error) // Returns AI provider rows. Soft-deleted and disabled rows are excluded // unless include_deleted or include_disabled is set. GetAIProviders(ctx context.Context, arg GetAIProvidersParams) ([]AIProvider, error) @@ -384,11 +386,6 @@ type sqlcQuerier interface { // personal chat model overrides. It defaults to false when unset. GetChatPersonalModelOverridesEnabled(ctx context.Context) (bool, error) GetChatPlanModeInstructions(ctx context.Context) (string, error) - GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) - GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error) - GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) - GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error) - GetChatProviders(ctx context.Context) ([]ChatProvider, error) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error) // Returns the chat retention period in days. Chats archived longer // than this and orphaned chat files older than this are purged by @@ -452,7 +449,6 @@ type sqlcQuerier interface { // Check both to ensure the selected config is actually usable. GetEnabledChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error) GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) - GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) @@ -760,7 +756,6 @@ type sqlcQuerier interface { GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) GetUserChatPersonalModelOverride(ctx context.Context, arg GetUserChatPersonalModelOverrideParams) (string, error) - GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) // Returns the total spend for a user in the given period. // When organization_id is NULL, spend across all organizations is // returned (global behavior). Otherwise only spend within the @@ -932,7 +927,6 @@ type sqlcQuerier interface { InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error) InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error) - InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error) @@ -1214,7 +1208,6 @@ type sqlcQuerier interface { UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error UpdateChatPlanModeByID(ctx context.Context, arg UpdateChatPlanModeByIDParams) (Chat, error) - UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) UpdateChatTitleByID(ctx context.Context, arg UpdateChatTitleByIDParams) (Chat, error) @@ -1283,7 +1276,6 @@ type sqlcQuerier interface { UpdateUserAgentChatSendShortcut(ctx context.Context, arg UpdateUserAgentChatSendShortcutParams) (string, error) UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error) UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error) - UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg UpdateUserCodeDiffDisplayModeParams) (string, error) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error @@ -1408,7 +1400,6 @@ type sqlcQuerier interface { UpsertUserAIProviderKey(ctx context.Context, arg UpsertUserAIProviderKeyParams) (UserAiProviderKey, error) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error UpsertUserChatPersonalModelOverride(ctx context.Context, arg UpsertUserChatPersonalModelOverrideParams) error - UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error) UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 5eb5ab1502..c874ebcbad 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -10609,20 +10609,12 @@ func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) { }, func(params *database.InsertChatModelConfigParams) { params.Enabled = false }) - legacyProvider := dbgen.ChatProvider(t, store, database.ChatProvider{Provider: "google"}) - legacyConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{ - Provider: legacyProvider.Provider, - Model: "google-model-" + uuid.NewString(), - }) configs, err := store.GetEnabledChatModelConfigs(ctx) require.NoError(t, err) require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool { return config.ID == enabledConfig.ID })) - require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool { - return config.ID == legacyConfig.ID - })) require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool { return config.ID == disabledProviderConfig.ID })) @@ -10636,6 +10628,46 @@ func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) { _, err = store.GetEnabledChatModelConfigByID(ctx, disabledProviderConfig.ID) require.ErrorIs(t, err, sql.ErrNoRows) + + _, err = store.GetEnabledChatModelConfigByID(ctx, disabledModelConfig.ID) + require.ErrorIs(t, err, sql.ErrNoRows) +} + +func insertChatModelConfigForTest( + ctx context.Context, + t testing.TB, + store database.Store, + params database.InsertChatModelConfigParams, +) (database.ChatModelConfig, error) { + t.Helper() + if params.AIProviderID.Valid { + return store.InsertChatModelConfig(ctx, params) + } + providerName := params.Provider + if providerName == "" { + providerName = "openai" + params.Provider = providerName + } + providers, err := store.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) + if err != nil { + return database.ChatModelConfig{}, err + } + var provider database.AIProvider + for _, candidate := range providers { + if candidate.Type != database.AIProviderType(providerName) { + continue + } + if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) { + provider = candidate + } + } + if provider.ID == uuid.Nil { + provider = dbgen.AIProvider(t, store, database.AIProvider{ + Type: database.AIProviderType(providerName), + }) + } + params.AIProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true} + return store.InsertChatModelConfig(ctx, params) } func TestInsertChatMessages(t *testing.T) { @@ -10653,7 +10685,7 @@ func TestInsertChatMessages(t *testing.T) { ) database.ChatModelConfig { t.Helper() - modelConfig, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelConfig, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: provider, Model: model, DisplayName: displayName, @@ -10681,14 +10713,13 @@ func TestInsertChatMessages(t *testing.T) { dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) provider := "openai" - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: provider, DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) modelConfigA := insertModelConfig( t, @@ -10850,18 +10881,21 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) { org := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - // A chat_providers row is required as a FK for model configs. - _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, + // An AI provider row is required as a FK for model configs. + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-" + uuid.NewString(), + DisplayName: sql.NullString{String: "OpenAI", Valid: true}, + Enabled: true, + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: "test-key", }) - require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{ Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Model: "test-model", DisplayName: "Test Model", CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, @@ -11227,16 +11261,15 @@ func TestGetPRInsights(t *testing.T) { user := dbgen.User(t, store, database.User{}) dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: "anthropic", DisplayName: "Anthropic", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - mc, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + mc, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: "anthropic", Model: "claude-4", DisplayName: "Claude 4", @@ -11683,7 +11716,7 @@ func TestGetPRInsights(t *testing.T) { store, userID, _, orgID := setupChatInfra(t) const modelName = "claude-4.1" - emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{ + emptyDisplayModel, err := insertChatModelConfigForTest(context.Background(), t, store, database.InsertChatModelConfigParams{ Provider: "anthropic", Model: modelName, DisplayName: "", @@ -11791,16 +11824,15 @@ func TestChatPinOrderQueries(t *testing.T) { // Use background context for fixture setup so the // timed test context doesn't tick during DB init. bg := context.Background() - _, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -11972,16 +12004,15 @@ func TestChatPinOrderConstraints(t *testing.T) { dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) bg := context.Background() - _, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -12065,16 +12096,15 @@ func TestChatLabels(t *testing.T) { org := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) - _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -12365,16 +12395,15 @@ func TestUpdateChatLastTurnSummary(t *testing.T) { org := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) - _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -12466,16 +12495,15 @@ func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) { providerName := "openai" modelName := "debug-model-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -12659,16 +12687,15 @@ func TestDeleteChatDebugDataAfterMessageIDStepLevelFieldBoundariesAndNulls(t *te providerName := "openai" modelName := "debug-model-step-boundaries-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -12917,16 +12944,15 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { providerName := "openai" modelName := "debug-model-finalize-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13356,16 +13382,15 @@ func TestChatDebugSQLGuards(t *testing.T) { providerName := "openai" modelName := "debug-model-guards-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13490,16 +13515,15 @@ func TestChatDebugRunCOALESCEPreservation(t *testing.T) { providerName := "openai" modelName := "debug-model-coalesce-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13605,16 +13629,15 @@ func TestChatDebugStepCOALESCEPreservation(t *testing.T) { providerName := "openai" modelName := "debug-step-coalesce-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13730,16 +13753,15 @@ func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) { providerName := "openai" modelName := "debug-model-null-msg-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13828,16 +13850,15 @@ func TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns(t *testi providerName := "openai" modelName := "debug-model-started-before-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13940,16 +13961,15 @@ func TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns(t *testing.T) providerName := "openai" modelName := "debug-model-by-chat-started-before-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -14029,17 +14049,13 @@ func TestGetChatsFilter(t *testing.T) { user := dbgen.User(t, store, database.User{}) dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, - }) - require.NoError(t, err) + provider := dbgen.AIProviderWithOptionalKey(t, store, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + }, "test-key") modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Model: "test-model-" + uuid.NewString(), DisplayName: "Test Model", CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, @@ -14265,16 +14281,15 @@ func TestChatHasUnread(t *testing.T) { user := dbgen.User(t, store, database.User{}) dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, }) - require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model-" + uuid.NewString(), DisplayName: "Test Model", diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1fd40ec3fa..571b3cb142 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -276,6 +276,51 @@ func (q *sqlQuerier) GetAIProviderKeysByProviderID(ctx context.Context, provider return items, nil } +const getAIProviderKeysByProviderIDs = `-- name: GetAIProviderKeysByProviderIDs :many +SELECT + id, provider_id, api_key, api_key_key_id, created_at, updated_at +FROM + ai_provider_keys +WHERE + provider_id = ANY($1::uuid[]) +ORDER BY + provider_id ASC, + created_at ASC, + id ASC +` + +// Returns all keys for the requested providers, ordered by provider then created_at ASC +// so callers can select the oldest non-empty key per provider without issuing N queries. +func (q *sqlQuerier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error) { + rows, err := q.db.QueryContext(ctx, getAIProviderKeysByProviderIDs, pq.Array(providerIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AIProviderKey + for rows.Next() { + var i AIProviderKey + if err := rows.Scan( + &i.ID, + &i.ProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertAIProviderKey = `-- name: InsertAIProviderKey :one INSERT INTO ai_provider_keys ( id, @@ -5020,6 +5065,23 @@ func (q *sqlQuerier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID return err } +const deleteChatModelConfigsByAIProviderID = `-- name: DeleteChatModelConfigsByAIProviderID :exec +UPDATE + chat_model_configs +SET + deleted = TRUE, + deleted_at = NOW(), + updated_at = NOW() +WHERE + ai_provider_id = $1::uuid + AND deleted = FALSE +` + +func (q *sqlQuerier) DeleteChatModelConfigsByAIProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteChatModelConfigsByAIProviderID, aiProviderID) + return err +} + const deleteChatModelConfigsByProvider = `-- name: DeleteChatModelConfigsByProvider :exec UPDATE chat_model_configs @@ -5164,18 +5226,14 @@ SELECT cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc -LEFT JOIN +JOIN ai_providers ap ON ap.id = cmc.ai_provider_id -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL WHERE cmc.id = $1::uuid AND cmc.deleted = FALSE AND cmc.enabled = TRUE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ) + AND ap.enabled = TRUE + AND ap.deleted = FALSE ` // Providers can be disabled independently of their model configs. @@ -5209,17 +5267,13 @@ SELECT cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc -LEFT JOIN +JOIN ai_providers ap ON ap.id = cmc.ai_provider_id -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL WHERE cmc.enabled = TRUE AND cmc.deleted = FALSE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ) + AND ap.enabled = TRUE + AND ap.deleted = FALSE ORDER BY cmc.provider ASC, cmc.model ASC, @@ -5435,369 +5489,6 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo return i, err } -const deleteChatProviderByID = `-- name: DeleteChatProviderByID :exec -DELETE FROM - chat_providers -WHERE - id = $1::uuid -` - -func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - _, err := q.db.ExecContext(ctx, deleteChatProviderByID, id) - return err -} - -const getChatProviderByID = `-- name: GetChatProviderByID :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - id = $1::uuid -` - -func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByID, id) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviderByIDForUpdate = `-- name: GetChatProviderByIDForUpdate :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - id = $1::uuid -FOR UPDATE -` - -func (q *sqlQuerier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByIDForUpdate, id) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - provider = $1::text -` - -func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByProvider, provider) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviderByProviderForUpdate = `-- name: GetChatProviderByProviderForUpdate :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - provider = $1::text -FOR UPDATE -` - -func (q *sqlQuerier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByProviderForUpdate, provider) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviders = `-- name: GetChatProviders :many -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -ORDER BY - provider ASC -` - -func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, error) { - rows, err := q.db.QueryContext(ctx, getChatProviders) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatProvider - for rows.Next() { - var i ChatProvider - if err := rows.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - enabled = TRUE -ORDER BY - provider ASC -` - -func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) { - rows, err := q.db.QueryContext(ctx, getEnabledChatProviders) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatProvider - for rows.Next() { - var i ChatProvider - if err := rows.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const insertChatProvider = `-- name: InsertChatProvider :one -INSERT INTO chat_providers ( - provider, - display_name, - api_key, - base_url, - api_key_key_id, - created_by, - enabled, - central_api_key_enabled, - allow_user_api_key, - allow_central_api_key_fallback -) VALUES ( - $1::text, - $2::text, - $3::text, - $4::text, - $5::text, - $6::uuid, - $7::boolean, - $8::boolean, - $9::boolean, - $10::boolean -) -RETURNING - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -` - -type InsertChatProviderParams struct { - Provider string `db:"provider" json:"provider"` - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - BaseUrl string `db:"base_url" json:"base_url"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - Enabled bool `db:"enabled" json:"enabled"` - CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` - AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` - AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` -} - -func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, insertChatProvider, - arg.Provider, - arg.DisplayName, - arg.APIKey, - arg.BaseUrl, - arg.ApiKeyKeyID, - arg.CreatedBy, - arg.Enabled, - arg.CentralApiKeyEnabled, - arg.AllowUserApiKey, - arg.AllowCentralApiKeyFallback, - ) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const updateChatProvider = `-- name: UpdateChatProvider :one -UPDATE - chat_providers -SET - display_name = $1::text, - api_key = $2::text, - base_url = $3::text, - api_key_key_id = $4::text, - enabled = $5::boolean, - central_api_key_enabled = $6::boolean, - allow_user_api_key = $7::boolean, - allow_central_api_key_fallback = $8::boolean, - updated_at = NOW() -WHERE - id = $9::uuid -RETURNING - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -` - -type UpdateChatProviderParams struct { - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - BaseUrl string `db:"base_url" json:"base_url"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - Enabled bool `db:"enabled" json:"enabled"` - CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` - AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` - AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` - ID uuid.UUID `db:"id" json:"id"` -} - -func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, updateChatProvider, - arg.DisplayName, - arg.APIKey, - arg.BaseUrl, - arg.ApiKeyKeyID, - arg.Enabled, - arg.CentralApiKeyEnabled, - arg.AllowUserApiKey, - arg.AllowCentralApiKeyFallback, - arg.ID, - ) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - const acquireChats = `-- name: AcquireChats :many WITH acquired_chats AS ( UPDATE @@ -27555,126 +27246,6 @@ func (q *sqlQuerier) UpdateUserSkillByUserIDAndName(ctx context.Context, arg Upd return i, err } -const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec -DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2 -` - -type DeleteUserChatProviderKeyParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` -} - -func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error { - _, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID) - return err -} - -const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many -SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC -` - -func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) { - rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []UserChatProviderKey - for rows.Next() { - var i UserChatProviderKey - if err := rows.Scan( - &i.ID, - &i.UserID, - &i.ChatProviderID, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one -UPDATE user_chat_provider_keys -SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW() -WHERE user_id = $3 AND chat_provider_id = $4 -RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at -` - -type UpdateUserChatProviderKeyParams struct { - APIKey string `db:"api_key" json:"api_key"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` -} - -func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) { - row := q.db.QueryRowContext(ctx, updateUserChatProviderKey, - arg.APIKey, - arg.ApiKeyKeyID, - arg.UserID, - arg.ChatProviderID, - ) - var i UserChatProviderKey - err := row.Scan( - &i.ID, - &i.UserID, - &i.ChatProviderID, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one -INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id) -VALUES ($1, $2, $3, $4::text) -ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET - api_key = $3, - api_key_key_id = $4::text, - updated_at = NOW() -RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at -` - -type UpsertUserChatProviderKeyParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` - APIKey string `db:"api_key" json:"api_key"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` -} - -func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) { - row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey, - arg.UserID, - arg.ChatProviderID, - arg.APIKey, - arg.ApiKeyKeyID, - ) - var i UserChatProviderKey - err := row.Scan( - &i.ID, - &i.UserID, - &i.ChatProviderID, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const allUserIDs = `-- name: AllUserIDs :many SELECT DISTINCT id FROM USERS WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END diff --git a/coderd/database/queries/ai_provider_keys.sql b/coderd/database/queries/ai_provider_keys.sql index 178c888079..d15fe6e4be 100644 --- a/coderd/database/queries/ai_provider_keys.sql +++ b/coderd/database/queries/ai_provider_keys.sql @@ -32,6 +32,20 @@ WHERE ORDER BY provider_id ASC; +-- name: GetAIProviderKeysByProviderIDs :many +-- Returns all keys for the requested providers, ordered by provider then created_at ASC +-- so callers can select the oldest non-empty key per provider without issuing N queries. +SELECT + * +FROM + ai_provider_keys +WHERE + provider_id = ANY(@provider_ids::uuid[]) +ORDER BY + provider_id ASC, + created_at ASC, + id ASC; + -- name: GetAIProviderKeys :many -- Returns AI provider key rows. By default, only rows whose parent -- provider is live (deleted = FALSE) are returned, so the API list diff --git a/coderd/database/queries/chatmodelconfigs.sql b/coderd/database/queries/chatmodelconfigs.sql index 1039357e99..2cf93698cf 100644 --- a/coderd/database/queries/chatmodelconfigs.sql +++ b/coderd/database/queries/chatmodelconfigs.sql @@ -34,17 +34,13 @@ SELECT cmc.* FROM chat_model_configs cmc -LEFT JOIN +JOIN ai_providers ap ON ap.id = cmc.ai_provider_id -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL WHERE cmc.enabled = TRUE AND cmc.deleted = FALSE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ) + AND ap.enabled = TRUE + AND ap.deleted = FALSE ORDER BY cmc.provider ASC, cmc.model ASC, @@ -58,18 +54,14 @@ FROM chat_model_configs cmc -- Providers can be disabled independently of their model configs. -- Check both to ensure the selected config is actually usable. -LEFT JOIN +JOIN ai_providers ap ON ap.id = cmc.ai_provider_id -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL WHERE cmc.id = @id::uuid AND cmc.deleted = FALSE AND cmc.enabled = TRUE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ); + AND ap.enabled = TRUE + AND ap.deleted = FALSE; -- name: InsertChatModelConfig :one INSERT INTO chat_model_configs ( @@ -151,3 +143,14 @@ SET WHERE provider = @provider::text AND deleted = FALSE; + +-- name: DeleteChatModelConfigsByAIProviderID :exec +UPDATE + chat_model_configs +SET + deleted = TRUE, + deleted_at = NOW(), + updated_at = NOW() +WHERE + ai_provider_id = @ai_provider_id::uuid + AND deleted = FALSE; diff --git a/coderd/database/queries/chatproviders.sql b/coderd/database/queries/chatproviders.sql deleted file mode 100644 index 7df983541d..0000000000 --- a/coderd/database/queries/chatproviders.sql +++ /dev/null @@ -1,102 +0,0 @@ --- name: GetChatProviderByID :one -SELECT - * -FROM - chat_providers -WHERE - id = @id::uuid; - --- name: GetChatProviderByIDForUpdate :one -SELECT - * -FROM - chat_providers -WHERE - id = @id::uuid -FOR UPDATE; - --- name: GetChatProviderByProvider :one -SELECT - * -FROM - chat_providers -WHERE - provider = @provider::text; - --- name: GetChatProviderByProviderForUpdate :one -SELECT - * -FROM - chat_providers -WHERE - provider = @provider::text -FOR UPDATE; - --- name: GetChatProviders :many -SELECT - * -FROM - chat_providers -ORDER BY - provider ASC; - --- name: GetEnabledChatProviders :many -SELECT - * -FROM - chat_providers -WHERE - enabled = TRUE -ORDER BY - provider ASC; - --- name: InsertChatProvider :one -INSERT INTO chat_providers ( - provider, - display_name, - api_key, - base_url, - api_key_key_id, - created_by, - enabled, - central_api_key_enabled, - allow_user_api_key, - allow_central_api_key_fallback -) VALUES ( - @provider::text, - @display_name::text, - @api_key::text, - @base_url::text, - sqlc.narg('api_key_key_id')::text, - sqlc.narg('created_by')::uuid, - @enabled::boolean, - @central_api_key_enabled::boolean, - @allow_user_api_key::boolean, - @allow_central_api_key_fallback::boolean -) -RETURNING - *; - --- name: UpdateChatProvider :one -UPDATE - chat_providers -SET - display_name = @display_name::text, - api_key = @api_key::text, - base_url = @base_url::text, - api_key_key_id = sqlc.narg('api_key_key_id')::text, - enabled = @enabled::boolean, - central_api_key_enabled = @central_api_key_enabled::boolean, - allow_user_api_key = @allow_user_api_key::boolean, - allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean, - updated_at = NOW() -WHERE - id = @id::uuid -RETURNING - *; - --- name: DeleteChatProviderByID :exec -DELETE FROM - chat_providers -WHERE - id = @id::uuid; diff --git a/coderd/database/queries/userchatproviderkeys.sql b/coderd/database/queries/userchatproviderkeys.sql deleted file mode 100644 index 38c177156e..0000000000 --- a/coderd/database/queries/userchatproviderkeys.sql +++ /dev/null @@ -1,20 +0,0 @@ --- name: GetUserChatProviderKeys :many -SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC; - --- name: UpsertUserChatProviderKey :one -INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id) -VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text) -ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET - api_key = @api_key, - api_key_key_id = sqlc.narg('api_key_key_id')::text, - updated_at = NOW() -RETURNING *; - --- name: UpdateUserChatProviderKey :one -UPDATE user_chat_provider_keys -SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW() -WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id -RETURNING *; - --- name: DeleteUserChatProviderKey :exec -DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 1afd078b8b..104b842ea8 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -25,8 +25,6 @@ const ( UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); - UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id); - UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider); UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id); UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id); UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton); @@ -99,8 +97,6 @@ const ( UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); UniqueUserAiProviderKeysPkey UniqueConstraint = "user_ai_provider_keys_pkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id); UniqueUserAiProviderKeysUserIDAiProviderIDKey UniqueConstraint = "user_ai_provider_keys_user_id_ai_provider_id_key" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); - UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); - UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id); UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type); diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index d1f82db663..2fba5db055 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -11,7 +11,6 @@ import ( "mime" "net/http" "net/http/httptest" - "net/url" "slices" "strconv" "strings" @@ -750,7 +749,9 @@ type userChatModelAvailability struct { configuredModels []chatprovider.ConfiguredModel enabledModels []database.ChatModelConfig providerStatus map[string]chatprovider.ProviderAvailability + providerStatusByID map[uuid.UUID]chatprovider.ProviderAvailability enabledProviderNames map[string]struct{} + enabledProviderIDs map[uuid.UUID]struct{} } // chatModelConfigUnavailableReason reports why a model config cannot be used. @@ -765,78 +766,132 @@ const ( chatModelConfigUnavailableCredentialsMissing chatModelConfigUnavailableReason = "credentials_missing" ) -// getUserChatProviderAvailability returns chat provider availability for a -// user. Deployment-level enabled providers and models are read with -// dbauthz.AsSystemRestricted(ctx) because they are global chat configuration, -// not user-owned resources. Callers must pass an authenticated user context so -// user-scoped model checks and provider-key lookups run under the caller's -// authorization. The returned struct contains configured providers and models -// for catalog listing, enabled model rows for ID validation, resolved provider -// status, and normalized enabled-provider membership. +// getUserChatProviderAvailability returns the enabled chat providers and models +// the user can access. Deployment-level configuration is read as chatd, while +// user key lookups still use the caller's authorization context. func (api *API) getUserChatProviderAvailability( ctx context.Context, userID uuid.UUID, ) (userChatModelAvailability, error) { - //nolint:gocritic // System context is required to read enabled chat config. - systemCtx := dbauthz.AsSystemRestricted(ctx) - enabledProviders, err := api.Database.GetEnabledChatProviders(systemCtx) + //nolint:gocritic // Chatd context is required to read enabled chat config. + chatdCtx := dbauthz.AsChatd(ctx) + enabledProviders, err := api.Database.GetAIProviders(chatdCtx, database.GetAIProvidersParams{}) if err != nil { return userChatModelAvailability{}, err } - enabledModels, err := api.Database.GetEnabledChatModelConfigs(systemCtx) + enabledModels, err := api.Database.GetEnabledChatModelConfigs(chatdCtx) if err != nil { return userChatModelAvailability{}, err } + configuredProviders, err := api.configuredProvidersFromAIProviders(chatdCtx, enabledProviders) + if err != nil { + return userChatModelAvailability{}, err + } availability := userChatModelAvailability{ - configuredProviders: make([]chatprovider.ConfiguredProvider, 0, len(enabledProviders)), + configuredProviders: configuredProviders, configuredModels: make([]chatprovider.ConfiguredModel, 0, len(enabledModels)), enabledModels: enabledModels, enabledProviderNames: make(map[string]struct{}, len(enabledProviders)), + enabledProviderIDs: make(map[uuid.UUID]struct{}, len(enabledProviders)), + providerStatusByID: make(map[uuid.UUID]chatprovider.ProviderAvailability, len(enabledProviders)), } - for _, provider := range enabledProviders { - availability.configuredProviders = append( - availability.configuredProviders, - chatprovider.ConfiguredProvider{ - ProviderID: provider.ID, - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserAPIKey: provider.AllowUserApiKey, - AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, - }, - ) - normalizedProvider := chatprovider.NormalizeProvider(provider.Provider) + for _, configuredProvider := range configuredProviders { + normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) if normalizedProvider != "" { availability.enabledProviderNames[normalizedProvider] = struct{}{} } + if configuredProvider.ProviderID != uuid.Nil { + availability.enabledProviderIDs[configuredProvider.ProviderID] = struct{}{} + } } + userKeys := []chatprovider.UserProviderKey{} + if api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() { + userKeyRows, err := api.Database.GetUserAIProviderKeysByUserID(ctx, userID) + if err != nil { + return userChatModelAvailability{}, err + } + userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) + for _, userKey := range userKeyRows { + userKeys = append(userKeys, chatprovider.UserProviderKey{ + ChatProviderID: userKey.AIProviderID, + APIKey: userKey.APIKey, + }) + } + } + + fallbackKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues) + mergeProviderStatus := func( + statuses map[string]chatprovider.ProviderAvailability, + normalizedProvider string, + status chatprovider.ProviderAvailability, + ) { + current, ok := statuses[normalizedProvider] + if !ok || (!current.Available && status.Available) { + statuses[normalizedProvider] = status + } + } + + providerStatusByType := make(map[string]chatprovider.ProviderAvailability, len(availability.configuredProviders)) + for _, configuredProvider := range availability.configuredProviders { + normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) + if normalizedProvider == "" { + continue + } + _, providerStatus := chatprovider.ResolveUserProviderKeys( + fallbackKeys, + []chatprovider.ConfiguredProvider{configuredProvider}, + userKeys, + ) + status, ok := providerStatus[normalizedProvider] + if !ok { + continue + } + if configuredProvider.ProviderID != uuid.Nil { + availability.providerStatusByID[configuredProvider.ProviderID] = status + } + mergeProviderStatus(providerStatusByType, normalizedProvider, status) + } + + modelStatusByType := make(map[string]chatprovider.ProviderAvailability, len(enabledModels)) for _, model := range enabledModels { + normalizedProvider := chatprovider.NormalizeProvider(model.Provider) + if normalizedProvider == "" { + continue + } + if model.AIProviderID.Valid { + status, ok := availability.providerStatusByID[model.AIProviderID.UUID] + if ok { + mergeProviderStatus(modelStatusByType, normalizedProvider, status) + } + continue + } + if status, ok := providerStatusByType[normalizedProvider]; ok { + mergeProviderStatus(modelStatusByType, normalizedProvider, status) + } + } + availability.providerStatus = providerStatusByType + for provider, status := range modelStatusByType { + availability.providerStatus[provider] = status + } + + for _, model := range enabledModels { + normalizedProvider := chatprovider.NormalizeProvider(model.Provider) + if model.AIProviderID.Valid { + status, ok := availability.providerStatusByID[model.AIProviderID.UUID] + if !ok { + continue + } + if aggregateStatus, ok := availability.providerStatus[normalizedProvider]; ok && aggregateStatus.Available && !status.Available { + continue + } + } availability.configuredModels = append(availability.configuredModels, chatprovider.ConfiguredModel{ Provider: model.Provider, Model: model.Model, DisplayName: model.DisplayName, }) } - - userKeyRows, err := api.Database.GetUserChatProviderKeys(ctx, userID) - if err != nil { - return userChatModelAvailability{}, err - } - userKeys := make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) - for _, userKey := range userKeyRows { - userKeys = append(userKeys, chatprovider.UserProviderKey{ - ChatProviderID: userKey.ChatProviderID, - APIKey: userKey.APIKey, - }) - } - - _, availability.providerStatus = chatprovider.ResolveUserProviderKeys( - ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), - availability.configuredProviders, - userKeys, - ) return availability, nil } @@ -870,6 +925,20 @@ func (api *API) userCanUseChatModelConfig( if err != nil { return chatModelConfigAvailable, err } + if model.AIProviderID.Valid { + providerID := model.AIProviderID.UUID + if _, ok := availability.enabledProviderIDs[providerID]; !ok { + return chatModelConfigUnavailableProviderDisabled, nil + } + providerStatus, ok := availability.providerStatusByID[providerID] + if !ok { + return chatModelConfigUnavailableProviderDisabled, nil + } + if !providerStatus.Available { + return chatModelConfigUnavailableCredentialsMissing, nil + } + return chatModelConfigAvailable, nil + } provider, _, err := chatprovider.ResolveModelWithProviderHint(model.Model, model.Provider) if err != nil { return chatModelConfigUnavailableProviderDisabled, nil @@ -6569,617 +6638,81 @@ func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusNoContent, nil) } -func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - //nolint:gocritic // System context required to read enabled chat providers. - systemCtx := dbauthz.AsSystemRestricted(ctx) - if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return +func (api *API) configuredProvidersFromAIProviders(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) { + if len(providers) == 0 { + return nil, nil } - - providers, err := api.Database.GetChatProviders(ctx) + providerIDs := make([]uuid.UUID, 0, len(providers)) + for _, provider := range providers { + providerIDs = append(providerIDs, provider.ID) + } + keys, err := api.Database.GetAIProviderKeysByProviderIDs(ctx, providerIDs) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to list chat providers.", - Detail: err.Error(), - }) - return + return nil, xerrors.Errorf("get AI provider keys: %w", err) + } + keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers)) + for _, key := range keys { + keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key) } - - providersByName := make(map[string]database.ChatProvider, len(providers)) configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) for _, provider := range providers { - normalizedProvider := normalizeChatProvider(provider.Provider) - if normalizedProvider == "" { - continue - } - provider.Provider = normalizedProvider - providersByName[normalizedProvider] = provider - configuredProviders = append(configuredProviders, chatprovider.ConfiguredProvider{ - Provider: normalizedProvider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - }) + configuredProviders = append(configuredProviders, api.configuredProviderFromAIProviderKeys(provider, keysByProviderID[provider.ID])) } - if api.chatDaemon == nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Chat processor is unavailable.", - Detail: "Chat processor is not configured.", - }) - return - } - - enabledProviders, err := api.Database.GetEnabledChatProviders( - systemCtx, - ) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to resolve provider API keys.", - Detail: err.Error(), - }) - return - } - - enabledConfiguredProviders := make( - []chatprovider.ConfiguredProvider, 0, len(enabledProviders), - ) - for _, provider := range enabledProviders { - normalizedProvider := normalizeChatProvider(provider.Provider) - if normalizedProvider == "" { - continue - } - enabledConfiguredProviders = append( - enabledConfiguredProviders, chatprovider.ConfiguredProvider{ - Provider: normalizedProvider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - }, - ) - } - - effectiveKeys := chatprovider.MergeProviderAPIKeys( - ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), - enabledConfiguredProviders, - ) - effectiveKeys = chatprovider.MergeProviderAPIKeys( - effectiveKeys, configuredProviders, - ) - - supportedProviders := chatprovider.SupportedProviders() - resp := make([]codersdk.ChatProviderConfig, 0, len(supportedProviders)) - for _, provider := range supportedProviders { - configured, ok := providersByName[provider] - if ok { - resp = append( - resp, - convertChatProviderConfig( - configured, - api.hasEffectiveProviderAPIKey(ctx, configured), - codersdk.ChatProviderConfigSourceDatabase, - ), - ) - continue - } - - source := codersdk.ChatProviderConfigSourceSupported - hasAPIKey := effectiveKeys.APIKey(provider) != "" - enabled := false - if chatprovider.IsEnvPresetProvider(provider) && hasAPIKey { - source = codersdk.ChatProviderConfigSourceEnvPreset - enabled = true - } - - resp = append(resp, codersdk.ChatProviderConfig{ - ID: uuid.Nil, - Provider: provider, - DisplayName: chatprovider.ProviderDisplayName(provider), - Enabled: enabled, - HasAPIKey: hasAPIKey, - CentralAPIKeyEnabled: true, - AllowUserAPIKey: false, - AllowCentralAPIKeyFallback: false, - BaseURL: effectiveKeys.BaseURL(provider), - Source: source, - }) - } - - httpapi.Write(ctx, rw, http.StatusOK, resp) + return configuredProviders, nil } -func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - apiKey := httpmw.APIKey(r) - var inserted database.ChatProvider - if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - var req codersdk.CreateChatProviderConfigRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - provider := normalizeChatProvider(req.Provider) - if provider == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: chatProviderValidationDetail(), - }) - return - } - - if err := validateChatProviderAPIKeySize(strings.TrimSpace(req.APIKey)); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key too large.", - Detail: err.Error(), - }) - return - } - - enabled := true - if req.Enabled != nil { - enabled = *req.Enabled - } - baseURL, err := normalizeChatProviderBaseURL(req.BaseURL) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider base URL.", - Detail: err.Error(), - }) - return - } - - centralAPIKeyEnabled := true - if req.CentralAPIKeyEnabled != nil { - centralAPIKeyEnabled = *req.CentralAPIKeyEnabled - } - allowUserAPIKey := false - if req.AllowUserAPIKey != nil { - allowUserAPIKey = *req.AllowUserAPIKey - } - allowCentralAPIKeyFallback := false - if req.AllowCentralAPIKeyFallback != nil { - allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback - } - - if err := validateChatProviderCredentialPolicy( - centralAPIKeyEnabled, - allowUserAPIKey, - allowCentralAPIKeyFallback, - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid credential policy.", - Detail: err.Error(), - }) - return - } - - if err := validateChatProviderCentralAPIKey( - provider, - centralAPIKeyEnabled, - api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{ - Provider: provider, - APIKey: strings.TrimSpace(req.APIKey), - BaseUrl: baseURL, - CentralApiKeyEnabled: centralAPIKeyEnabled, - }, uuid.Nil), - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: err.Error(), - }) - return - } - - inserted, err = api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: provider, - DisplayName: strings.TrimSpace(req.DisplayName), - APIKey: strings.TrimSpace(req.APIKey), - BaseUrl: baseURL, - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil}, - Enabled: enabled, - CentralApiKeyEnabled: centralAPIKeyEnabled, - AllowUserApiKey: allowUserAPIKey, - AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, - }) - if err != nil { - switch { - case database.IsUniqueViolation(err): - httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ - Message: "Chat provider already exists.", - Detail: err.Error(), - }) - return - case database.IsCheckViolation(err): - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: err.Error(), - }) - return - default: - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to create chat provider.", - Detail: err.Error(), - }) - return +func (api *API) configuredProviderFromAIProviderKeys(provider database.AIProvider, keys []database.AIProviderKey) chatprovider.ConfiguredProvider { + apiKey := "" + for _, key := range keys { + if key.APIKey != "" { + apiKey = key.APIKey + break } } - - publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) - - httpapi.Write( - ctx, - rw, - http.StatusCreated, - convertChatProviderConfig( - inserted, - api.hasEffectiveProviderAPIKey(ctx, inserted), - codersdk.ChatProviderConfigSourceDatabase, - ), - ) -} - -func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - var ( - existing database.ChatProvider - updated database.ChatProvider - ) - if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } - - existing, err := api.Database.GetChatProviderByID(ctx, providerID) - if err != nil { - if httpapi.Is404Error(err) { - httpapi.ResourceNotFound(rw) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get chat provider.", - Detail: err.Error(), - }) - return - } - - var req codersdk.UpdateChatProviderConfigRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - displayName := existing.DisplayName - if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" { - displayName = trimmed - } - - enabled := existing.Enabled - if req.Enabled != nil { - enabled = *req.Enabled - } - - apiKey := existing.APIKey - apiKeyKeyID := existing.ApiKeyKeyID - if req.APIKey != nil { - trimmedAPIKey := strings.TrimSpace(*req.APIKey) - if trimmedAPIKey != "" { - if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key too large.", - Detail: err.Error(), - }) - return - } - } - apiKey = trimmedAPIKey - apiKeyKeyID = sql.NullString{} - } - baseURL := existing.BaseUrl - if req.BaseURL != nil { - baseURL, err = normalizeChatProviderBaseURL(*req.BaseURL) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider base URL.", - Detail: err.Error(), - }) - return - } - } - - centralAPIKeyEnabled := existing.CentralApiKeyEnabled - if req.CentralAPIKeyEnabled != nil { - centralAPIKeyEnabled = *req.CentralAPIKeyEnabled - } - allowUserAPIKey := existing.AllowUserApiKey - if req.AllowUserAPIKey != nil { - allowUserAPIKey = *req.AllowUserAPIKey - } - allowCentralAPIKeyFallback := existing.AllowCentralApiKeyFallback - if req.AllowCentralAPIKeyFallback != nil { - allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback - } - - if err := validateChatProviderCredentialPolicy( - centralAPIKeyEnabled, - allowUserAPIKey, - allowCentralAPIKeyFallback, - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid credential policy.", - Detail: err.Error(), - }) - return - } - - if err := validateChatProviderCentralAPIKey( - existing.Provider, - centralAPIKeyEnabled, - api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{ - ID: existing.ID, - Provider: existing.Provider, - APIKey: apiKey, - BaseUrl: baseURL, - CentralApiKeyEnabled: centralAPIKeyEnabled, - }, existing.ID), - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: err.Error(), - }) - return - } - - updated, err = api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: displayName, + return chatprovider.ConfiguredProvider{ + ProviderID: provider.ID, + Provider: string(provider.Type), APIKey: apiKey, - BaseUrl: baseURL, - ApiKeyKeyID: apiKeyKeyID, - Enabled: enabled, - CentralApiKeyEnabled: centralAPIKeyEnabled, - AllowUserApiKey: allowUserAPIKey, - AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, - ID: existing.ID, + BaseURL: provider.BaseUrl, + CentralAPIKeyEnabled: true, + AllowUserAPIKey: api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), + AllowCentralAPIKeyFallback: true, + } +} + +func writeLegacyChatProviderGone(rw http.ResponseWriter, r *http.Request) { + httpapi.Write(r.Context(), rw, http.StatusGone, codersdk.Response{ + Message: "Legacy chat provider APIs were removed. Use AI provider APIs instead.", + Detail: "See https://coder.com/docs/ai-coder/agents/models#providers for AI provider configuration.", }) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to update chat provider.", - Detail: err.Error(), - }) - return - } - - publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) - - httpapi.Write( - ctx, - rw, - http.StatusOK, - convertChatProviderConfig( - updated, - api.hasEffectiveProviderAPIKey(ctx, updated), - codersdk.ChatProviderConfigSourceDatabase, - ), - ) } -func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } - - err := api.Database.InTx(func(tx database.Store) error { - provider, err := tx.GetChatProviderByIDForUpdate(ctx, providerID) - switch { - case err == nil: - if err := tx.DeleteChatModelConfigsByProvider(ctx, provider.Provider); err != nil { - return xerrors.Errorf("soft delete chat model configs for provider %q: %w", provider.Provider, err) - } - if err := ensureDefaultChatModelConfig(ctx, tx); err != nil { - return err - } - if err := tx.DeleteChatProviderByID(ctx, provider.ID); err != nil { - return xerrors.Errorf("delete chat provider %s: %w", provider.ID, err) - } - return nil - case xerrors.Is(err, sql.ErrNoRows): - return err - default: - return xerrors.Errorf("get chat provider %s for delete: %w", providerID, err) - } - }, nil) - if err != nil { - if xerrors.Is(err, sql.ErrNoRows) { - httpapi.ResourceNotFound(rw) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to delete chat provider.", - Detail: err.Error(), - }) - return - } - - publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) - - rw.WriteHeader(http.StatusNoContent) +func (*API) listChatProviders(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - apiKey = httpmw.APIKey(r) - ) - - //nolint:gocritic // Non-admin users need to read provider configs to manage their own chat credentials. - chatdCtx := dbauthz.AsChatd(ctx) - providers, err := api.Database.GetChatProviders(chatdCtx) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to list chat providers.", - Detail: err.Error(), - }) - return - } - - userKeys, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to list user chat provider keys.", - Detail: err.Error(), - }) - return - } - - hasUserAPIKeyByProviderID := make(map[uuid.UUID]bool, len(userKeys)) - for _, userKey := range userKeys { - hasUserAPIKeyByProviderID[userKey.ChatProviderID] = true - } - - resp := make([]codersdk.UserChatProviderConfig, 0, len(providers)) - for _, provider := range providers { - if !provider.Enabled || !provider.AllowUserApiKey { - continue - } - hasUserAPIKey := hasUserAPIKeyByProviderID[provider.ID] - hasCentralAPIKeyFallback := provider.Enabled && - provider.AllowCentralApiKeyFallback && - api.hasEffectiveCentralProviderCredentials(ctx, provider, uuid.Nil) - resp = append( - resp, - convertUserChatProviderConfig( - provider, - hasUserAPIKey, - hasCentralAPIKeyFallback, - api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), - ), - ) - } - - httpapi.Write(ctx, rw, http.StatusOK, resp) +func (*API) createChatProvider(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - apiKey = httpmw.APIKey(r) - ) - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } - - //nolint:gocritic // Non-admin users need to validate provider availability before storing their own key. - provider, err := api.Database.GetChatProviderByID(dbauthz.AsChatd(ctx), providerID) - if err != nil { - if httpapi.Is404Error(err) { - httpapi.ResourceNotFound(rw) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get chat provider.", - Detail: err.Error(), - }) - return - } - if !provider.Enabled { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Provider is disabled.", - }) - return - } - if !provider.AllowUserApiKey { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Provider does not allow user API keys.", - }) - return - } - - var req codersdk.CreateUserChatProviderKeyRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - trimmedAPIKey := strings.TrimSpace(req.APIKey) - if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key too large.", - Detail: err.Error(), - }) - return - } - if trimmedAPIKey == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key is required.", - }) - return - } - - if _, err := api.Database.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: apiKey.UserID, - ChatProviderID: providerID, - APIKey: trimmedAPIKey, - ApiKeyKeyID: sql.NullString{}, - }); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to save user chat provider key.", - Detail: err.Error(), - }) - return - } - - hasCentralAPIKeyFallback := provider.Enabled && - provider.AllowCentralApiKeyFallback && - api.hasEffectiveCentralProviderCredentials(ctx, provider, uuid.Nil) - httpapi.Write( - ctx, - rw, - http.StatusOK, - convertUserChatProviderConfig( - provider, - true, - hasCentralAPIKeyFallback, - api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), - ), - ) +func (*API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - apiKey = httpmw.APIKey(r) - ) +func (*API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) +} - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } +func (*API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) +} - if err := api.Database.DeleteUserChatProviderKey(ctx, database.DeleteUserChatProviderKeyParams{ - UserID: apiKey.UserID, - ChatProviderID: providerID, - }); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to delete user chat provider key.", - Detail: err.Error(), - }) - return - } +func (*API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) +} - rw.WriteHeader(http.StatusNoContent) +func (*API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { @@ -7196,7 +6729,7 @@ func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { configs, err = api.Database.GetChatModelConfigs(ctx) } else { //nolint:gocritic // All authenticated users need to read enabled model configs to use the chat feature. - configs, err = api.Database.GetEnabledChatModelConfigs(dbauthz.AsSystemRestricted(ctx)) + configs, err = api.Database.GetEnabledChatModelConfigs(dbauthz.AsChatd(ctx)) } if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -7423,6 +6956,18 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { return } + if strings.TrimSpace(req.Provider) != "" && req.AIProviderID == nil { + requestedProvider := chatprovider.NormalizeProvider(req.Provider) + if requestedProvider == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid provider."}) + return + } + if requestedProvider != existing.Provider { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required when updating provider."}) + return + } + } + provider := existing.Provider aiProviderID := existing.AIProviderID if req.AIProviderID != nil { @@ -7445,19 +6990,6 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { } provider = string(aiProvider.Type) aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true} - } else if strings.TrimSpace(req.Provider) != "" { - requestedProvider := normalizeChatProvider(req.Provider) - if requestedProvider == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: chatProviderValidationDetail(), - }) - return - } - provider = requestedProvider - if requestedProvider != existing.Provider { - aiProviderID = uuid.NullUUID{} - } } model := existing.Model @@ -7544,10 +7076,6 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { return errChatProviderNotConfigured } updateParams.Provider = string(aiProvider.Type) - } else if !updateParams.AIProviderID.Valid { - if err := requireChatProviderForModelConfig(ctx, tx, updateParams.Provider); err != nil { - return err - } } setAsDefault := updateParams.IsDefault && !existing.IsDefault @@ -7777,18 +7305,6 @@ func parseChatUsageLimitUserID(rw http.ResponseWriter, r *http.Request) (uuid.UU return userID, true } -func parseChatProviderID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { - providerID, err := uuid.Parse(chi.URLParam(r, "providerConfig")) - if err != nil { - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid chat provider ID.", - Detail: err.Error(), - }) - return uuid.Nil, false - } - return providerID, true -} - func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { modelConfigID, err := uuid.Parse(chi.URLParam(r, "modelConfig")) if err != nil { @@ -7801,53 +7317,6 @@ func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, return modelConfigID, true } -func convertChatProviderConfig( - provider database.ChatProvider, - hasAPIKey bool, - source codersdk.ChatProviderConfigSource, -) codersdk.ChatProviderConfig { - displayName := strings.TrimSpace(provider.DisplayName) - if displayName == "" { - displayName = chatprovider.ProviderDisplayName(provider.Provider) - } - - return codersdk.ChatProviderConfig{ - ID: provider.ID, - Provider: provider.Provider, - DisplayName: displayName, - Enabled: provider.Enabled, - HasAPIKey: hasAPIKey, - CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserAPIKey: provider.AllowUserApiKey, - AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, - BaseURL: strings.TrimSpace(provider.BaseUrl), - Source: source, - CreatedAt: provider.CreatedAt, - UpdatedAt: provider.UpdatedAt, - } -} - -func convertUserChatProviderConfig( - provider database.ChatProvider, - hasUserAPIKey bool, - hasCentralAPIKeyFallback bool, - byokEnabled bool, -) codersdk.UserChatProviderConfig { - displayName := strings.TrimSpace(provider.DisplayName) - if displayName == "" { - displayName = chatprovider.ProviderDisplayName(provider.Provider) - } - - return codersdk.UserChatProviderConfig{ - ProviderID: provider.ID, - Provider: provider.Provider, - DisplayName: displayName, - HasUserAPIKey: hasUserAPIKey, - HasCentralAPIKeyFallback: hasCentralAPIKeyFallback, - BYOKEnabled: byokEnabled, - } -} - func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig { var aiProviderID *uuid.UUID if config.AIProviderID.Valid { @@ -7981,57 +7450,6 @@ func isZeroChatModelProviderOptions(options *codersdk.ChatModelProviderOptions) options.Vercel == nil } -func normalizeChatProvider(provider string) string { - return chatprovider.NormalizeProvider(provider) -} - -func normalizeChatProviderBaseURL(raw string) (string, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "", nil - } - - parsed, err := url.Parse(trimmed) - if err != nil { - return "", err - } - if parsed.Scheme == "" || parsed.Host == "" { - return "", xerrors.New("Base URL must be an absolute URL with scheme and host.") - } - if parsed.Scheme != "http" && parsed.Scheme != "https" { - return "", xerrors.New("Base URL scheme must be http or https.") - } - return parsed.String(), nil -} - -func chatProviderValidationDetail() string { - return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "." -} - -var ( - errChatModelConfigNotFound = xerrors.New("chat model config not found") - errChatProviderNotConfigured = xerrors.New("chat provider is not configured") -) - -// requireChatProviderForModelConfig takes a FOR UPDATE lock on the provider -// row to serialize model-config writes with deleteChatProvider. Do not swap -// this call for the non-locking provider lookup. -func requireChatProviderForModelConfig( - ctx context.Context, - tx database.Store, - provider string, -) error { - _, err := tx.GetChatProviderByProviderForUpdate(ctx, provider) - switch { - case err == nil: - return nil - case xerrors.Is(err, sql.ErrNoRows): - return errChatProviderNotConfigured - default: - return xerrors.Errorf("get chat provider %q: %w", provider, err) - } -} - const maxChatProviderAPIKeySize = 10240 // 10 KB func validateChatProviderAPIKeySize(apiKey string) error { @@ -8041,42 +7459,10 @@ func validateChatProviderAPIKeySize(apiKey string) error { return nil } -//nolint:revive // This helper validates the explicit credential policy tuple. -func validateChatProviderCredentialPolicy( - centralEnabled, allowUserKey, allowFallback bool, -) error { - if !centralEnabled && !allowUserKey { - return xerrors.New( - "At least one credential source must be enabled: central API key or user API key.", - ) - } - if allowFallback && !centralEnabled { - return xerrors.New( - "Central API key fallback requires central API key to be enabled.", - ) - } - if allowFallback && !allowUserKey { - return xerrors.New( - "Central API key fallback requires user API key to be enabled.", - ) - } - return nil -} - -//nolint:revive // This helper validates central-key requirements. -func validateChatProviderCentralAPIKey( - provider string, - centralEnabled bool, - hasCentralAPIKey bool, -) error { - if !centralEnabled || hasCentralAPIKey { - return nil - } - if chatprovider.ProviderAllowsAmbientCredentials(provider) { - return nil - } - return xerrors.New("API key is required when central API key is enabled.") -} +var ( + errChatModelConfigNotFound = xerrors.New("chat model config not found") + errChatProviderNotConfigured = xerrors.New("chat provider is not configured") +) // ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat // provider API keys. @@ -8089,77 +7475,6 @@ func ChatProviderAPIKeysFromDeploymentValues( return chatprovider.ProviderAPIKeys{} } -func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool { - return api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil) -} - -func (api *API) hasEffectiveCentralProviderCredentials( - ctx context.Context, - provider database.ChatProvider, - excludeProviderID uuid.UUID, -) bool { - if api.hasEffectiveCentralProviderAPIKey(ctx, provider, excludeProviderID) { - return true - } - return provider.CentralApiKeyEnabled && - chatprovider.ProviderAllowsAmbientCredentials(provider.Provider) -} - -func (api *API) hasEffectiveCentralProviderAPIKey( - ctx context.Context, - provider database.ChatProvider, - excludeProviderID uuid.UUID, -) bool { - if !provider.CentralApiKeyEnabled { - return false - } - if strings.TrimSpace(provider.APIKey) != "" { - return true - } - deploymentKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues) - if deploymentKeys.APIKey(provider.Provider) != "" { - return true - } - if api.chatDaemon == nil { - return false - } - //nolint:gocritic // System context required to read enabled chat providers. - systemCtx := dbauthz.AsSystemRestricted(ctx) - - enabledProviders, err := api.Database.GetEnabledChatProviders( - systemCtx, - ) - if err != nil { - api.Logger.Warn(ctx, "failed to resolve provider API keys", - slog.F("provider", provider.Provider), - slog.Error(err), - ) - return false - } - - enabledConfiguredProviders := make( - []chatprovider.ConfiguredProvider, 0, len(enabledProviders), - ) - for _, configured := range enabledProviders { - if excludeProviderID != uuid.Nil && configured.ID == excludeProviderID { - continue - } - enabledConfiguredProviders = append( - enabledConfiguredProviders, chatprovider.ConfiguredProvider{ - Provider: configured.Provider, - APIKey: configured.APIKey, - BaseURL: configured.BaseUrl, - }, - ) - } - - effectiveKeys := chatprovider.MergeProviderAPIKeys( - deploymentKeys, - enabledConfiguredProviders, - ) - return effectiveKeys.APIKey(provider.Provider) != "" -} - // @Summary Get PR insights // @ID get-pr-insights // @Security CoderSessionToken diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 2111742a11..07f5581d7f 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -890,26 +890,6 @@ func TestPostChats_ClientType(t *testing.T) { func TestListChats(t *testing.T) { t.Parallel() - sortedChatIDs := func(ids []uuid.UUID) []uuid.UUID { - out := append([]uuid.UUID(nil), ids...) - slices.SortFunc(out, func(a, b uuid.UUID) int { - return strings.Compare(a.String(), b.String()) - }) - return out - } - requireRootIDs := func(t *testing.T, chats []codersdk.Chat, want ...uuid.UUID) []uuid.UUID { - t.Helper() - - got := make([]uuid.UUID, 0, len(chats)) - for _, chat := range chats { - require.Nil(t, chat.ParentChatID, "list should only return root chats") - got = append(got, chat.ID) - } - - require.Equal(t, sortedChatIDs(want), sortedChatIDs(got)) - return got - } - t.Run("Success", func(t *testing.T) { t.Parallel() @@ -1559,171 +1539,6 @@ func TestListChats(t *testing.T) { require.Equal(t, archivedWithPR.ID, chats[0].ID) }) }) - - t.Run("TitleSearch", func(t *testing.T) { - t.Parallel() - - client, db := newChatClientWithDatabase(t) - firstUser := coderdtest.CreateFirstUser(t, client.Client) - modelConfig := createChatModelConfig(t, client) - - // Verify that the title: filter is wired through the endpoint. - // Exhaustive ILIKE behavior is tested in TestGetChatsFilter (Title/* subtests). - alpha := dbgen.Chat(t, db, database.Chat{ - OrganizationID: firstUser.OrganizationID, - OwnerID: firstUser.UserID, - LastModelConfigID: modelConfig.ID, - Title: "alpha project", - }) - _ = dbgen.Chat(t, db, database.Chat{ - OrganizationID: firstUser.OrganizationID, - OwnerID: firstUser.UserID, - LastModelConfigID: modelConfig.ID, - Title: "beta unrelated", - }) - - ctx := testutil.Context(t, testutil.WaitLong) - - t.Run("SingleWord", func(t *testing.T) { - chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "title:alpha"}) - require.NoError(t, err) - requireRootIDs(t, chats, alpha.ID) - }) - - t.Run("MultiWord", func(t *testing.T) { - chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: `title:"alpha project"`}) - require.NoError(t, err) - requireRootIDs(t, chats, alpha.ID) - }) - - t.Run("BareTermsRejected", func(t *testing.T) { - _, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "bare words"}) - requireSDKError(t, err, http.StatusBadRequest) - }) - }) - - t.Run("PRStatusFilter", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) - user := coderdtest.CreateFirstUser(t, client.Client) - modelConfig := createChatModelConfig(t, client) - - // Verify that pr_status filter is wired through the endpoint. - // Exhaustive query logic is tested in TestGetChatsFilter (PRStatus/* subtests). - createChatWithPR := func(title, prURL, prState string, prDraft bool) database.Chat { - t.Helper() - - chat := dbgen.Chat(t, db, database.Chat{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - LastModelConfigID: modelConfig.ID, - Title: title, - Status: database.ChatStatusCompleted, - }) - refreshedAt := time.Now().UTC().Truncate(time.Second) - staleAt := refreshedAt.Add(time.Hour) - _, err := db.UpsertChatDiffStatusReference( - dbauthz.AsSystemRestricted(ctx), - database.UpsertChatDiffStatusReferenceParams{ - ChatID: chat.ID, - Url: sql.NullString{String: prURL, Valid: true}, - GitBranch: "feature/test", - GitRemoteOrigin: "git@github.com:coder/coder.git", - StaleAt: staleAt, - }, - ) - require.NoError(t, err) - _, err = db.UpsertChatDiffStatus( - dbauthz.AsSystemRestricted(ctx), - database.UpsertChatDiffStatusParams{ - ChatID: chat.ID, - Url: sql.NullString{String: prURL, Valid: true}, - PullRequestState: sql.NullString{String: prState, Valid: true}, - PullRequestDraft: prDraft, - RefreshedAt: refreshedAt, - StaleAt: staleAt, - }, - ) - require.NoError(t, err) - return chat - } - - draftChat := createChatWithPR("draft pr", "https://github.com/coder/coder/pull/301", "open", true) - _ = createChatWithPR("open pr", "https://github.com/coder/coder/pull/302", "open", false) - - t.Run("MatchesDraft", func(t *testing.T) { - chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "pr_status:draft"}) - require.NoError(t, err) - requireRootIDs(t, chats, draftChat.ID) - }) - - t.Run("InvalidPRStatus", func(t *testing.T) { - _, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "pr_status:bogus"}) - requireSDKError(t, err, http.StatusBadRequest) - }) - }) - - t.Run("UnreadFilter", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) - user := coderdtest.CreateFirstUser(t, client.Client) - modelConfig := createChatModelConfig(t, client) - - // Verify that has_unread:true filter is wired through the endpoint. - // Exhaustive query logic is tested in TestGetChatsFilter (Unread/* subtests). - unreadChat := dbgen.Chat(t, db, database.Chat{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - LastModelConfigID: modelConfig.ID, - Title: "unread chat", - Status: database.ChatStatusCompleted, - }) - _, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ - ChatID: unreadChat.ID, - CreatedBy: []uuid.UUID{user.UserID}, - ModelConfigID: []uuid.UUID{modelConfig.ID}, - Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, - Content: []string{`[{"type":"text","text":"hello"}]`}, - ContentVersion: []int16{0}, - Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, - InputTokens: []int64{0}, - OutputTokens: []int64{0}, - TotalTokens: []int64{0}, - ReasoningTokens: []int64{0}, - CacheCreationTokens: []int64{0}, - CacheReadTokens: []int64{0}, - ContextLimit: []int64{0}, - Compressed: []bool{false}, - TotalCostMicros: []int64{0}, - RuntimeMs: []int64{0}, - ProviderResponseID: []string{""}, - }) - require.NoError(t, err) - - // Create a second chat with NO unread messages to prove filtering works. - _ = dbgen.Chat(t, db, database.Chat{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - LastModelConfigID: modelConfig.ID, - Title: "read chat", - Status: database.ChatStatusCompleted, - }) - - t.Run("MatchesUnread", func(t *testing.T) { - chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "has_unread:true"}) - require.NoError(t, err) - requireRootIDs(t, chats, unreadChat.ID) - }) - - t.Run("InvalidHasUnread", func(t *testing.T) { - _, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "has_unread:bogus"}) - requireSDKError(t, err, http.StatusBadRequest) - }) - }) } func TestListChatModels(t *testing.T) { @@ -1802,18 +1617,12 @@ func TestListChatModels(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client.Client) providerType := database.AiProviderTypeAnthropic - chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: string(providerType), - CentralAPIKeyEnabled: ptr.Ref(false), - AllowUserAPIKey: ptr.Ref(true), - }) - require.NoError(t, err) - aiProvider := createAIProviderForTest(t, client, string(providerType), "") + provider := createAIProviderForTest(t, client, string(providerType), "") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: string(providerType), - AIProviderID: &aiProvider.ID, + AIProviderID: &provider.ID, Model: "claude-sonnet", ContextLimit: &contextLimit, }) @@ -1833,7 +1642,7 @@ func TestListChatModels(t *testing.T) { require.False(t, anthropicProvider.Available) require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason) - _, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{ + _, err = client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{ APIKey: "user-api-key", }) require.NoError(t, err) @@ -1859,20 +1668,12 @@ func TestListChatModels(t *testing.T) { client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) - chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "google", - APIKey: "central-api-key", - CentralAPIKeyEnabled: ptr.Ref(true), - AllowUserAPIKey: ptr.Ref(true), - AllowCentralAPIKeyFallback: ptr.Ref(true), - }) - require.NoError(t, err) - aiProvider := createAIProviderForTest(t, client, "google", "provider-api-key") + provider := createAIProviderForTest(t, client, "google", "provider-api-key") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "google", - AIProviderID: &aiProvider.ID, + AIProviderID: &provider.ID, Model: "gemini-1.5-pro", ContextLimit: &contextLimit, }) @@ -1891,7 +1692,7 @@ func TestListChatModels(t *testing.T) { require.NotNil(t, googleProvider) require.True(t, googleProvider.Available) - _, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{ + _, err = client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{ APIKey: "user-api-key", }) require.NoError(t, err) @@ -1919,17 +1720,12 @@ func TestListChatModels(t *testing.T) { client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) - chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-key", - }) - require.NoError(t, err) - aiProvider := createAIProviderForTest(t, client, "openai", "test-key") + provider := createAIProviderForTest(t, client, "openai", "test-key") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", - AIProviderID: &aiProvider.ID, + AIProviderID: &provider.ID, Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) @@ -1943,7 +1739,7 @@ func TestListChatModels(t *testing.T) { require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model) enabled := false - _, err = client.UpdateChatProvider(ctx, chatProvider.ID, codersdk.UpdateChatProviderConfigRequest{ + _, err = client.UpdateAIProvider(ctx, provider.ID.String(), codersdk.UpdateAIProviderRequest{ Enabled: &enabled, }) require.NoError(t, err) @@ -2450,6 +2246,7 @@ func TestUserAIProviderKeys(t *testing.T) { func TestListChatProviders(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("Success", func(t *testing.T) { t.Parallel() @@ -2519,6 +2316,7 @@ func TestListChatProviders(t *testing.T) { func TestCreateChatProvider(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("Success", func(t *testing.T) { t.Parallel() @@ -2819,6 +2617,7 @@ func TestCreateChatProvider(t *testing.T) { func TestUpdateChatProvider(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("Success", func(t *testing.T) { t.Parallel() @@ -3111,282 +2910,7 @@ func TestUpdateChatProvider(t *testing.T) { func TestDeleteChatProvider(t *testing.T) { t.Parallel() - - t.Run("Success", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client := newChatClient(t) - _ = coderdtest.CreateFirstUser(t, client.Client) - - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) - - err = client.DeleteChatProvider(ctx, provider.ID) - require.NoError(t, err) - - providers, err := client.ListChatProviders(ctx) - require.NoError(t, err) - for _, listed := range providers { - require.NotEqual(t, provider.ID, listed.ID) - } - }) - - t.Run("SuccessWithHistoricalChats", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) - firstUser := coderdtest.CreateFirstUser(t, client.Client) - - providerToDelete, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "delete-api-key", - AllowUserAPIKey: ptr.Ref(true), - }) - require.NoError(t, err) - aiProviderToDelete := createAIProviderForTest(t, client, providerToDelete.Provider, "delete-api-key") - - deleteContextLimit := int64(4096) - deleteIsDefault := true - configToDelete, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: providerToDelete.Provider, - AIProviderID: &aiProviderToDelete.ID, - Model: "gpt-4o-delete-provider", - ContextLimit: &deleteContextLimit, - IsDefault: &deleteIsDefault, - }) - require.NoError(t, err) - - keepProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "anthropic", - APIKey: "keep-api-key", - }) - require.NoError(t, err) - keepAIProvider := createAIProviderForTest(t, client, keepProvider.Provider, "keep-api-key") - - keepContextLimit := int64(8192) - keepConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: keepProvider.Provider, - AIProviderID: &keepAIProvider.ID, - Model: "claude-keep-provider", - ContextLimit: &keepContextLimit, - }) - require.NoError(t, err) - - chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ - OrganizationID: firstUser.OrganizationID, - ModelConfigID: ptr.Ref(configToDelete.ID), - Content: []codersdk.ChatInputPart{{ - Type: codersdk.ChatInputPartTypeText, - Text: "provider delete history " + t.Name(), - }}, - }) - require.NoError(t, err) - require.Equal(t, configToDelete.ID, chat.LastModelConfigID) - - insertAssistantCostMessage(t, db, chat.ID, configToDelete.ID, 500) - - _, err = client.UpsertUserChatProviderKey(ctx, providerToDelete.ID, codersdk.CreateUserChatProviderKeyRequest{ - APIKey: "user-delete-key", - }) - require.NoError(t, err) - - userKeys, err := db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID) - require.NoError(t, err) - require.Len(t, userKeys, 1) - require.Equal(t, providerToDelete.ID, userKeys[0].ChatProviderID) - - err = client.DeleteChatProvider(ctx, providerToDelete.ID) - require.NoError(t, err) - - _, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), providerToDelete.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - providers, err := client.ListChatProviders(ctx) - require.NoError(t, err) - foundKeepProvider := false - for _, listed := range providers { - require.NotEqual(t, providerToDelete.ID, listed.ID) - if listed.ID == keepProvider.ID { - foundKeepProvider = true - } - } - require.True(t, foundKeepProvider) - - configs, err := client.ListChatModelConfigs(ctx) - require.NoError(t, err) - foundDeletedConfig := false - foundKeepConfig := false - for _, config := range configs { - if config.ID == configToDelete.ID { - foundDeletedConfig = true - } - if config.ID == keepConfig.ID { - foundKeepConfig = true - require.True(t, config.IsDefault) - } - } - require.False(t, foundDeletedConfig) - require.True(t, foundKeepConfig) - - defaultConfig, err := db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx)) - require.NoError(t, err) - require.Equal(t, keepConfig.ID, defaultConfig.ID) - - _, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), configToDelete.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - gotChat, err := client.GetChat(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, chat.ID, gotChat.ID) - require.Equal(t, configToDelete.ID, gotChat.LastModelConfigID) - - messages, err := client.GetChatMessages(ctx, chat.ID, nil) - require.NoError(t, err) - foundHistoricalMessage := false - for _, message := range messages.Messages { - if message.ModelConfigID != nil && *message.ModelConfigID == configToDelete.ID { - foundHistoricalMessage = true - break - } - } - require.True(t, foundHistoricalMessage) - - userKeys, err = db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID) - require.NoError(t, err) - require.Empty(t, userKeys) - }) - - t.Run("SuccessWithHistoricalChatsAndNoReplacementConfig", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) - firstUser := coderdtest.CreateFirstUser(t, client.Client) - - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "only-provider-api-key", - }) - require.NoError(t, err) - aiProvider := createAIProviderForTest(t, client, provider.Provider, "only-provider-api-key") - - contextLimit := int64(4096) - isDefault := true - config, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, - AIProviderID: &aiProvider.ID, - Model: "gpt-4o-only-provider", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) - - chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ - OrganizationID: firstUser.OrganizationID, - ModelConfigID: ptr.Ref(config.ID), - Content: []codersdk.ChatInputPart{{ - Type: codersdk.ChatInputPartTypeText, - Text: "only provider delete history " + t.Name(), - }}, - }) - require.NoError(t, err) - require.Equal(t, config.ID, chat.LastModelConfigID) - - insertAssistantCostMessage(t, db, chat.ID, config.ID, 250) - - err = client.DeleteChatProvider(ctx, provider.ID) - require.NoError(t, err) - - providers, err := client.ListChatProviders(ctx) - require.NoError(t, err) - for _, listed := range providers { - require.NotEqual(t, provider.ID, listed.ID) - } - - _, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), provider.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - _, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), config.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - _, err = db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx)) - require.ErrorIs(t, err, sql.ErrNoRows) - - configs, err := client.ListChatModelConfigs(ctx) - require.NoError(t, err) - require.Empty(t, configs) - - gotChat, err := client.GetChat(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, config.ID, gotChat.LastModelConfigID) - - messages, err := client.GetChatMessages(ctx, chat.ID, nil) - require.NoError(t, err) - foundHistoricalMessage := false - for _, message := range messages.Messages { - if message.ModelConfigID != nil && *message.ModelConfigID == config.ID { - foundHistoricalMessage = true - break - } - } - require.True(t, foundHistoricalMessage) - }) - - t.Run("NotFound", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client := newChatClient(t) - _ = coderdtest.CreateFirstUser(t, client.Client) - - err := client.DeleteChatProvider(ctx, uuid.New()) - requireSDKError(t, err, http.StatusNotFound) - }) - - t.Run("InvalidProviderID", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client := newChatClient(t) - _ = coderdtest.CreateFirstUser(t, client.Client) - - res, err := client.Request( - ctx, - http.MethodDelete, - "/api/experimental/chats/providers/not-a-uuid", - nil, - ) - require.NoError(t, err) - defer res.Body.Close() - - err = codersdk.ReadBodyAsError(res) - sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Invalid chat provider ID.", sdkErr.Message) - }) - - t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - adminClient := newChatClient(t) - firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) - memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) - memberClient := codersdk.NewExperimentalClient(memberClientRaw) - - provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) - - err = memberClient.DeleteChatProvider(ctx, provider.ID) - requireSDKError(t, err, http.StatusForbidden) - }) + t.Skip("legacy chat provider API removed in favor of AI provider API") } func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) { @@ -3414,6 +2938,7 @@ func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) { func TestUserChatProviderConfigs(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") requireUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) codersdk.UserChatProviderConfig { t.Helper() @@ -3809,6 +3334,7 @@ func TestUserChatProviderConfigs(t *testing.T) { func TestUpsertUserChatProviderKey(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { t.Parallel() @@ -4263,6 +3789,26 @@ func TestUpdateChatModelConfig(t *testing.T) { requireChatModelPricing(t, configs[0].ModelConfig, pricing) }) + t.Run("UnchangedProviderWithoutAIProviderID", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ + Provider: modelConfig.Provider, + Model: "gpt-4o-mini-updated", + }) + require.NoError(t, err) + require.Equal(t, modelConfig.ID, updated.ID) + require.Equal(t, modelConfig.Provider, updated.Provider) + require.NotNil(t, updated.AIProviderID) + require.Equal(t, *modelConfig.AIProviderID, *updated.AIProviderID) + require.Equal(t, "gpt-4o-mini-updated", updated.Model) + }) + t.Run("DisablePreservesRecordAndHidesItFromNonAdmins", func(t *testing.T) { t.Parallel() @@ -4482,11 +4028,12 @@ func TestUpdateChatModelConfig(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) + missingProviderID := uuid.New() _, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ - Provider: "anthropic", + AIProviderID: &missingProviderID, }) sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed) - require.Equal(t, "Chat provider is not configured.", sdkErr.Message) + require.Equal(t, "AI provider is not configured.", sdkErr.Message) }) t.Run("NotFoundWhenTargetRowDisappearsInTx", func(t *testing.T) { @@ -10953,32 +10500,27 @@ func enableUserChatProviderKey( adminClient *codersdk.ExperimentalClient, userClient *codersdk.ExperimentalClient, providerName string, -) codersdk.ChatProviderConfig { +) codersdk.AIProvider { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) - providers, err := adminClient.ListChatProviders(ctx) + providers, err := adminClient.AIProviders(ctx) require.NoError(t, err) - var provider codersdk.ChatProviderConfig + var provider codersdk.AIProvider for _, candidate := range providers { - if candidate.Provider == providerName && candidate.Source == codersdk.ChatProviderConfigSourceDatabase { + if candidate.Type == codersdk.AIProviderType(providerName) { provider = candidate break } } require.NotEqual(t, uuid.Nil, provider.ID) - updated, err := adminClient.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ - AllowUserAPIKey: ptr.Ref(true), - }) - require.NoError(t, err) - - _, err = userClient.UpsertUserChatProviderKey(ctx, updated.ID, codersdk.CreateUserChatProviderKeyRequest{ + _, err = userClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{ APIKey: "test-user-api-key-" + uuid.NewString(), }) require.NoError(t, err) - return updated + return provider } //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. @@ -11792,13 +11334,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) { defaultModelConfig := createChatModelConfig(t, adminClient) provider := enableUserChatProviderKey(t, adminClient, memberClient, defaultModelConfig.Provider) - modelConfig := createAdditionalChatModelConfig( - t, - adminClient, - defaultModelConfig.Provider, - "gpt-4o-personal-"+uuid.NewString(), - ) - err := adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{ + modelProvider := createAIProviderForTest(t, adminClient, "anthropic", "") + _, err := memberClient.UpsertUserAIProviderKey(ctx, "me", modelProvider.ID, codersdk.CreateUserAIProviderKeyRequest{ + APIKey: "test-user-api-key-" + uuid.NewString(), + }) + require.NoError(t, err) + contextLimit := int64(4096) + modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + AIProviderID: &modelProvider.ID, + Model: "claude-personal-" + uuid.NewString(), + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + err = adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{ ModelConfigID: modelConfig.ID.String(), }) require.NoError(t, err) @@ -11813,19 +11362,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) { defaultModelConfig.Provider, "gpt-4o-personal-disabled-"+uuid.NewString(), ) - disabledProvider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "anthropic", - Enabled: ptr.Ref(false), - CentralAPIKeyEnabled: ptr.Ref(false), - AllowUserAPIKey: ptr.Ref(true), + disabledProvider := createAIProviderForTest(t, adminClient, "google", "test-api-key") + contextLimit = int64(4096) + disabledProviderModelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "google", + AIProviderID: &disabledProvider.ID, + Model: "gemini-personal-disabled-provider-" + uuid.NewString(), + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + enabled := false + disabledProvider, err = adminClient.UpdateAIProvider(ctx, disabledProvider.ID.String(), codersdk.UpdateAIProviderRequest{ + Enabled: &enabled, }) require.NoError(t, err) - disabledProviderModelConfig := createAdditionalChatModelConfig( - t, - adminClient, - "anthropic", - "claude-personal-disabled-provider-"+uuid.NewString(), - ) require.NotEqual(t, uuid.Nil, provider.ID) require.NotEqual(t, uuid.Nil, disabledProvider.ID) @@ -12153,12 +11703,19 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) { firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) defaultModel := createChatModelConfig(t, adminClient) _ = enableUserChatProviderKey(t, adminClient, adminClient, defaultModel.Provider) - overrideModel := createAdditionalChatModelConfig( - t, - adminClient, - defaultModel.Provider, - "gpt-4o-root-personal-"+uuid.NewString(), - ) + overrideProvider := createAIProviderForTest(t, adminClient, "anthropic", "") + _, err := adminClient.UpsertUserAIProviderKey(ctx, "me", overrideProvider.ID, codersdk.CreateUserAIProviderKeyRequest{ + APIKey: "test-user-api-key-" + uuid.NewString(), + }) + require.NoError(t, err) + contextLimit := int64(4096) + overrideModel, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + AIProviderID: &overrideProvider.ID, + Model: "claude-root-personal-" + uuid.NewString(), + ContextLimit: &contextLimit, + }) + require.NoError(t, err) disabledModel := createDisabledChatModelConfig( t, adminClient, @@ -12203,7 +11760,7 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) { require.NoError(t, err) } - err := adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{ + err = adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{ AllowUsers: true, }) require.NoError(t, err) diff --git a/coderd/workspaceagents_chat_context_internal_test.go b/coderd/workspaceagents_chat_context_internal_test.go index 5a2c8e25be..cf3811d64b 100644 --- a/coderd/workspaceagents_chat_context_internal_test.go +++ b/coderd/workspaceagents_chat_context_internal_test.go @@ -2,6 +2,7 @@ package coderd import ( "context" + "database/sql" "encoding/json" "testing" "time" @@ -96,17 +97,21 @@ func insertAgentChatTestModelConfig( createdBy := uuid.NullUUID{UUID: userID, Valid: true} - _ = dbgen.ChatProvider(t, db, database.ChatProvider{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-api-key", - CreatedBy: createdBy, + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-openai", + DisplayName: sql.NullString{String: "OpenAI", Valid: true}, + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: "test-api-key", }) return dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - Provider: "openai", - CreatedBy: createdBy, - UpdatedBy: createdBy, - IsDefault: true, + Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + CreatedBy: createdBy, + UpdatedBy: createdBy, + IsDefault: true, }) } diff --git a/coderd/x/chatd/advisor_internal_test.go b/coderd/x/chatd/advisor_internal_test.go index ad81e580b2..48a937b3d8 100644 --- a/coderd/x/chatd/advisor_internal_test.go +++ b/coderd/x/chatd/advisor_internal_test.go @@ -24,8 +24,8 @@ import ( // advisorOverrideStubStore stubs only the database methods that // resolveAdvisorModelOverride exercises. The prod code calls -// GetEnabledChatModelConfigByID so the query joins chat_providers and -// filters both enabled flags atomically; tests simulate that by returning +// GetEnabledChatModelConfigByID so the query joins ai_providers and +// filters both enabled flags atomically. Tests simulate that by returning // configs the stub treats as enabled. type advisorOverrideStubStore struct { database.Store diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 3b4429a789..c8c3d8968f 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -350,13 +350,8 @@ func (p *Server) resolveAdvisorModelOverride( return fallbackModel, fallbackCallConfig } - // GetEnabledChatModelConfigByID checks the model config and referenced - // provider enabled state, so it returns sql.ErrNoRows the moment an - // admin disables either one. Using the cached ModelConfigByID here - // would keep resolving an override whose provider was just disabled, - // and an available fallback key would let ModelFromConfig succeed, - // silently routing advisor prompts to a provider the admin expects to - // be off. + // Re-read the override instead of using the cache so disabled models + // or providers stop routing advisor prompts immediately. overrideConfig, err := p.db.GetEnabledChatModelConfigByID( ctx, advisorCfg.ModelConfigID, @@ -7102,9 +7097,8 @@ func (p *Server) runChat( // Fire title generation asynchronously so it doesn't block the // chat response. It uses a detached context so it can finish // even after the chat processing context is canceled. - // Snapshot model, provider keys, logger, and ctx before launch; all four get - // reassigned below (model = cuModel, providerKeys = computerUseProviderKeys, - // logger = logger.With(...), ctx = runCtx) and the goroutine captures by reference. + // Snapshot values captured by the goroutine because model, providerKeys, + // logger, and ctx are reassigned below. titleModel := model titleProviderKeys := providerKeys titleLogger := logger @@ -8502,13 +8496,17 @@ func (p *Server) resolveChatModel( } func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) { - if !provider.Enabled { - return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) - } keys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID) if err != nil { return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err) } + return p.aiProviderConfigFromKeys(provider, keys) +} + +func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { + if !provider.Enabled { + return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) + } apiKey := "" // GetAIProviderKeysByProviderID orders keys oldest first. chatd consumes // one provider-scoped key because runtime provider config has one API key slot. @@ -8529,6 +8527,48 @@ func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvi }, nil } +func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) { + if len(providers) == 0 { + return nil, nil + } + providerIDs := make([]uuid.UUID, 0, len(providers)) + for _, provider := range providers { + providerIDs = append(providerIDs, provider.ID) + } + keys, err := p.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs) + if err != nil { + return nil, xerrors.Errorf("get AI provider keys: %w", err) + } + keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers)) + for _, key := range keys { + keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key) + } + configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) + for _, provider := range providers { + configuredProvider, err := p.aiProviderConfigFromKeys(provider, keysByProviderID[provider.ID]) + if err != nil { + return nil, err + } + configuredProviders = append(configuredProviders, configuredProvider) + } + return configuredProviders, nil +} + +func ensureUniqueConfiguredProviderTypes(providers []chatprovider.ConfiguredProvider) error { + seen := make(map[string]uuid.UUID, len(providers)) + for _, provider := range providers { + normalizedProvider := chatprovider.NormalizeProvider(provider.Provider) + if normalizedProvider == "" { + continue + } + if existingProviderID, ok := seen[normalizedProvider]; ok && existingProviderID != provider.ProviderID { + return xerrors.Errorf("multiple enabled AI providers use provider type %q; select an AI provider by ID", normalizedProvider) + } + seen[normalizedProvider] = provider.ProviderID + } + return nil +} + func (p *Server) resolveUserProviderAPIKeysForProvider( ctx context.Context, ownerID uuid.UUID, @@ -8559,12 +8599,6 @@ func (p *Server) resolveUserProviderAPIKeysForProvider( []chatprovider.ConfiguredProvider{configuredProvider}, userKeys, ) - enabledProviders := map[string]struct{}{} - normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) - if normalizedProvider != "" { - enabledProviders[normalizedProvider] = struct{}{} - } - chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders) return keys, nil } @@ -8598,9 +8632,6 @@ func (p *Server) resolveUserProviderAPIKeys( ownerID uuid.UUID, selectedAIProviderID uuid.UUID, ) (chatprovider.ProviderAPIKeys, error) { - var configuredProviders []chatprovider.ConfiguredProvider - userKeys := []chatprovider.UserProviderKey{} - if selectedAIProviderID != uuid.Nil { provider, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID) if err != nil { @@ -8608,48 +8639,35 @@ func (p *Server) resolveUserProviderAPIKeys( } return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) } + providers, err := p.configCache.EnabledProviders(ctx) if err != nil { return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( - "get enabled chat providers: %w", + "get enabled AI providers: %w", err, ) } - configuredProviders = make( - []chatprovider.ConfiguredProvider, 0, len(providers), - ) - for _, provider := range providers { - configuredProviders = append( - configuredProviders, chatprovider.ConfiguredProvider{ - ProviderID: provider.ID, - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserAPIKey: provider.AllowUserApiKey, - AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, - }, - ) + configuredProviders, err := p.aiProviderConfigs(ctx, providers) + if err != nil { + return chatprovider.ProviderAPIKeys{}, err } - allowAnyUserAPIKey := false - for _, provider := range configuredProviders { - if provider.AllowUserAPIKey { - allowAnyUserAPIKey = true - break - } + if err := ensureUniqueConfiguredProviderTypes(configuredProviders); err != nil { + return chatprovider.ProviderAPIKeys{}, err } - if allowAnyUserAPIKey { - userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID) + + userKeys := []chatprovider.UserProviderKey{} + if p.allowBYOK { + userKeyRows, err := p.db.GetUserAIProviderKeysByUserID(ctx, ownerID) if err != nil { return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( - "get user chat provider keys: %w", + "get user AI provider keys: %w", err, ) } userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) for _, userKey := range userKeyRows { userKeys = append(userKeys, chatprovider.UserProviderKey{ - ChatProviderID: userKey.ChatProviderID, + ChatProviderID: userKey.AIProviderID, APIKey: userKey.APIKey, }) } diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index cbdd7a21dc..b4d891bd06 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -797,12 +797,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { } db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - CentralApiKeyEnabled: true, - APIKey: "test-key", - BaseUrl: serverURL, + providerID := uuid.New() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: serverURL, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil) db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( gomock.Any(), @@ -961,12 +963,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t } db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - CentralApiKeyEnabled: true, - APIKey: "test-key", - BaseUrl: serverURL, + providerID := uuid.New() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: serverURL, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil) db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( gomock.Any(), @@ -1110,11 +1114,13 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) { }, } - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "anthropic", - CentralApiKeyEnabled: true, - AllowCentralApiKeyFallback: true, + providerID := uuid.New() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: providerID, + Type: database.AiProviderTypeAnthropic, + Enabled: true, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil) keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) @@ -1185,10 +1191,13 @@ func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKe }, } - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - CentralApiKeyEnabled: true, + providerID := uuid.New() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil) keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) @@ -3922,7 +3931,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) { db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return( database.ChatModelConfig{}, xerrors.New("no model configured"), ).AnyTimes() - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return( database.ChatUsageLimitConfig{}, sql.ErrNoRows, @@ -5696,7 +5705,7 @@ func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) { return database.ChatModelConfig{}, chatloop.ErrInterrupted }, ).AnyTimes() - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return( database.ChatUsageLimitConfig{}, sql.ErrNoRows, diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index a2edfad93a..e9a0e8e1c2 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -6120,21 +6120,24 @@ func setOpenAIProviderBaseURL( ) { t.Helper() - provider, err := db.GetChatProviderByProvider(ctx, "openai") - require.NoError(t, err) - - _, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: baseURL, - ApiKeyKeyID: provider.ApiKeyKeyID, - Enabled: provider.Enabled, - CentralApiKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserApiKey: provider.AllowUserApiKey, - AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, - }) + providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) require.NoError(t, err) + for _, provider := range providers { + if provider.Type != database.AiProviderTypeOpenai { + continue + } + _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ + ID: provider.ID, + DisplayName: provider.DisplayName, + Enabled: provider.Enabled, + BaseUrl: baseURL, + Settings: provider.Settings, + SettingsKeyID: provider.SettingsKeyID, + }) + require.NoError(t, err) + return + } + require.Fail(t, "openai provider not found") } func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) { @@ -7193,10 +7196,11 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) { true, false, ) - _, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: user.ID, - ChatProviderID: provider.ID, - APIKey: userAPIKey, + _, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AIProviderID: provider.ID, + APIKey: userAPIKey, }) require.NoError(t, err) diff --git a/coderd/x/chatd/configcache.go b/coderd/x/chatd/configcache.go index e23509df8b..69470a9473 100644 --- a/coderd/x/chatd/configcache.go +++ b/coderd/x/chatd/configcache.go @@ -30,7 +30,7 @@ const ( ) type cachedProviders struct { - providers []database.ChatProvider + providers []database.AIProvider expiresAt time.Time } @@ -74,7 +74,7 @@ type chatConfigCache struct { // Providers (singleton). providers *cachedProviders providerGeneration uint64 - providerFetches singleflight.Group[string, []database.ChatProvider] + providerFetches singleflight.Group[string, []database.AIProvider] // Model configs (keyed by ID). modelTopologyEpoch uint64 @@ -131,7 +131,7 @@ func singleflightDoChan[K comparable, V any]( } } -func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.ChatProvider, error) { +func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.AIProvider, error) { if providers, ok := c.cachedProviders(); ok { return providers, nil } @@ -141,12 +141,12 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat ctx, &c.providerFetches, fmt.Sprintf("%d:providers", generation), - func() ([]database.ChatProvider, error) { + func() ([]database.AIProvider, error) { if cached, ok := c.cachedProviders(); ok { return cached, nil } - fetched, err := c.db.GetEnabledChatProviders(c.ctx) + fetched, err := c.db.GetAIProviders(c.ctx, database.GetAIProvidersParams{}) if err != nil { return nil, err } @@ -161,7 +161,7 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat return slices.Clone(providers), nil } -func (c *chatConfigCache) cachedProviders() ([]database.ChatProvider, bool) { +func (c *chatConfigCache) cachedProviders() ([]database.AIProvider, bool) { c.mu.RLock() entry := c.providers c.mu.RUnlock() @@ -188,7 +188,7 @@ func (c *chatConfigCache) providersGeneration() uint64 { return generation } -func (c *chatConfigCache) storeProviders(generation uint64, providers []database.ChatProvider) { +func (c *chatConfigCache) storeProviders(generation uint64, providers []database.AIProvider) { c.mu.Lock() defer c.mu.Unlock() diff --git a/coderd/x/chatd/configcache_test.go b/coderd/x/chatd/configcache_test.go index 8213cd5d9b..ee0855f2f8 100644 --- a/coderd/x/chatd/configcache_test.go +++ b/coderd/x/chatd/configcache_test.go @@ -22,7 +22,7 @@ import ( type stubChatConfigStore struct { database.Store - getEnabledChatProviders func(context.Context) ([]database.ChatProvider, error) + getAIProviders func(context.Context) ([]database.AIProvider, error) getChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error) getDefaultChatModelConfig func(context.Context) (database.ChatModelConfig, error) getUserChatCustomPrompt func(context.Context, uuid.UUID) (string, error) @@ -35,12 +35,12 @@ type stubChatConfigStore struct { advisorConfigCalls atomic.Int32 } -func (s *stubChatConfigStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { +func (s *stubChatConfigStore) GetAIProviders(ctx context.Context, _ database.GetAIProvidersParams) ([]database.AIProvider, error) { s.enabledProvidersCalls.Add(1) - if s.getEnabledChatProviders == nil { - panic("unexpected GetEnabledChatProviders call") + if s.getAIProviders == nil { + panic("unexpected GetAIProviders call") } - return s.getEnabledChatProviders(ctx) + return s.getAIProviders(ctx) } func (s *stubChatConfigStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) { @@ -80,9 +80,9 @@ func TestConfigCache_EnabledProviders_CacheHit(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) - providers := []database.ChatProvider{testChatProvider("provider-a")} + providers := []database.AIProvider{testAIProvider("provider-a")} store := &stubChatConfigStore{ - getEnabledChatProviders: func(context.Context) ([]database.ChatProvider, error) { + getAIProviders: func(context.Context) ([]database.AIProvider, error) { return providers, nil }, } @@ -104,9 +104,9 @@ func TestConfigCache_EnabledProviders_TTLExpiry(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() - return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil + return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil } cache := newChatConfigCache(ctx, store, clock) @@ -126,9 +126,9 @@ func TestConfigCache_EnabledProviders_Invalidation(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() - return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil + return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil } cache := newChatConfigCache(ctx, store, clock) @@ -398,12 +398,12 @@ func TestConfigCache_Singleflight(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - providers := []database.ChatProvider{testChatProvider("provider-a")} + providers := []database.AIProvider{testAIProvider("provider-a")} fetchStarted := make(chan struct{}) releaseFetch := make(chan struct{}) var startedOnce sync.Once store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { startedOnce.Do(func() { close(fetchStarted) }) <-releaseFetch return providers, nil @@ -411,7 +411,7 @@ func TestConfigCache_Singleflight(t *testing.T) { cache := newChatConfigCache(ctx, store, clock) const callers = 8 - results := make([][]database.ChatProvider, callers) + results := make([][]database.AIProvider, callers) errs := make([]error, callers) var wg sync.WaitGroup start := make(chan struct{}) @@ -441,13 +441,13 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - firstProviders := []database.ChatProvider{testChatProvider("provider-a")} - secondProviders := []database.ChatProvider{testChatProvider("provider-b")} + firstProviders := []database.AIProvider{testAIProvider("provider-a")} + secondProviders := []database.AIProvider{testAIProvider("provider-b")} fetchStarted := make(chan struct{}) releaseFetch := make(chan struct{}) var startedOnce sync.Once store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() if call == 1 { startedOnce.Do(func() { close(fetchStarted) }) @@ -458,7 +458,7 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) { } cache := newChatConfigCache(ctx, store, clock) - resultCh := make(chan []database.ChatProvider, 1) + resultCh := make(chan []database.AIProvider, 1) errCh := make(chan error, 1) go func() { providers, err := cache.EnabledProviders(ctx) @@ -494,14 +494,14 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - staleProviders := []database.ChatProvider{testChatProvider("provider-stale")} - freshProviders := []database.ChatProvider{testChatProvider("provider-fresh")} + staleProviders := []database.AIProvider{testAIProvider("provider-stale")} + freshProviders := []database.AIProvider{testAIProvider("provider-fresh")} firstStarted := make(chan struct{}) secondStarted := make(chan struct{}) releaseFirst := make(chan struct{}) releaseSecond := make(chan struct{}) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { switch call := store.enabledProvidersCalls.Load(); call { case 1: close(firstStarted) @@ -518,7 +518,7 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing cache := newChatConfigCache(ctx, store, clock) type result struct { - providers []database.ChatProvider + providers []database.AIProvider err error } @@ -670,11 +670,12 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig(t *testi require.Equal(t, int32(2), store.modelConfigByIDCalls.Load()) } -func testChatProvider(name string) database.ChatProvider { - return database.ChatProvider{ +func testAIProvider(name string) database.AIProvider { + return database.AIProvider{ ID: uuid.New(), - Provider: name, - DisplayName: name, + Type: database.AIProviderType(name), + Name: name, + DisplayName: sql.NullString{String: name, Valid: true}, Enabled: true, CreatedAt: time.Unix(0, 0).UTC(), UpdatedAt: time.Unix(0, 0).UTC(), @@ -737,19 +738,19 @@ func TestConfigCache_CallerCancellation(t *testing.T) { name: "EnabledProviders", setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) { var once sync.Once - store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) { once.Do(func() { close(started) }) select { case <-ctx.Done(): return nil, ctx.Err() case <-release: - return []database.ChatProvider{testChatProvider("p")}, nil + return []database.AIProvider{testAIProvider("p")}, nil } } }, setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) { var once sync.Once - store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) { + store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) { once.Do(func() { close(started) }) <-ctx.Done() return nil, ctx.Err() diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 1f557f5eee..29d6fef9d2 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -154,12 +154,12 @@ func validateModelConfigAndResolveProvider( } func enabledProviderContainsName( - providers []database.ChatProvider, + providers []database.AIProvider, providerName string, ) bool { normalizedProviderName := chatprovider.NormalizeProvider(providerName) for _, provider := range providers { - if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName { + if chatprovider.NormalizeProvider(string(provider.Type)) == normalizedProviderName { return true } } diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 5a0b7fdc89..55254db1a2 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -2,6 +2,7 @@ package chatd import ( "context" + "database/sql" "encoding/json" "sync" "testing" @@ -183,13 +184,15 @@ func seedInternalChatDeps( UserID: user.ID, OrganizationID: org.ID, }) - dbgen.ChatProvider(t, db, database.ChatProvider{ + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", }) model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - IsDefault: true, + Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + IsDefault: true, }) return user, org, model @@ -309,24 +312,38 @@ func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) { require.True(t, keys.HasProvider("bedrock")) require.Empty(t, keys.APIKey("bedrock")) }) + + t.Run("RejectsAmbiguousProviderTypeWithoutSelectedProvider", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + ctx := chatdTestContext(t) + user, _, _ := seedInternalChatDeps(t, db) + insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "first-provider-api-key", true) + insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "second-provider-api-key", true) + + keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, uuid.Nil) + require.ErrorContains(t, err, "multiple enabled AI providers use provider type") + require.Equal(t, chatprovider.ProviderAPIKeys{}, keys) + }) } func TestResolveChatModel_AIProviderDisabled(t *testing.T) { t.Parallel() ctx := chatdTestContext(t) - db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t) + db, ps := dbtestutil.NewDB(t) user, org, _ := seedInternalChatDeps(t, db) provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", false) modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ Provider: "openai", Model: "gpt-4o-mini", + AIProviderID: uuid.NullUUID{ + UUID: provider.ID, + Valid: true, + }, }) - _, err := sqlDB.ExecContext(ctx, "UPDATE chat_model_configs SET ai_provider_id = $1 WHERE id = $2", provider.ID, modelConfig.ID) - require.NoError(t, err) - loadedModelConfig, err := db.GetChatModelConfigByID(ctx, modelConfig.ID) - require.NoError(t, err) - require.True(t, loadedModelConfig.AIProviderID.Valid) server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) chat := dbgen.Chat(t, db, database.Chat{ OrganizationID: org.ID, @@ -408,19 +425,20 @@ func insertInternalChatProvider( centralAPIKeyEnabled bool, allowUserAPIKey bool, allowCentralAPIKeyFallback bool, -) database.ChatProvider { +) database.AIProvider { t.Helper() - providerConfig := dbgen.ChatProvider(t, db, database.ChatProvider{ - Provider: provider, - DisplayName: provider, - CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, - }, func(p *database.InsertChatProviderParams) { - p.APIKey = apiKey - p.CentralApiKeyEnabled = centralAPIKeyEnabled - p.AllowUserApiKey = allowUserAPIKey - p.AllowCentralApiKeyFallback = allowCentralAPIKeyFallback + providerConfig := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AIProviderType(provider), + Name: "test-" + uuid.NewString(), + DisplayName: sql.NullString{String: provider, Valid: true}, }) + if apiKey != "" { + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: providerConfig.ID, + APIKey: apiKey, + }) + } return providerConfig } diff --git a/coderd/x/chatd/title_override_test.go b/coderd/x/chatd/title_override_test.go index de2227af96..942bea6ac8 100644 --- a/coderd/x/chatd/title_override_test.go +++ b/coderd/x/chatd/title_override_test.go @@ -348,7 +348,8 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback( db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{uuid.Nil}).Return(nil, nil) generated := &generatedChatTitle{} server := titleOverrideTestServer(db, logger) @@ -498,7 +499,8 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() server := titleOverrideTestServer(db, logger) model, gotConfig, _, err := server.resolveManualTitleModel( @@ -524,7 +526,8 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() server := titleOverrideTestServer(db, logger) model, gotConfig, _, err := server.resolveManualTitleModel( diff --git a/coderd/x/chatd/turn_summary_internal_test.go b/coderd/x/chatd/turn_summary_internal_test.go index 87abdaf881..be3a595799 100644 --- a/coderd/x/chatd/turn_summary_internal_test.go +++ b/coderd/x/chatd/turn_summary_internal_test.go @@ -33,16 +33,15 @@ func TestUpdateLastTurnSummaryRejectsStaleWrites(t *testing.T) { OrganizationID: org.ID, }) - _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, }) - require.NoError(t, err) modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -102,16 +101,15 @@ func TestPendingChatPersistsSummaryButSkipsWebPush(t *testing.T) { OrganizationID: org.ID, }) - _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, }) - require.NoError(t, err) modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Provider: "openai", Model: "test-model", DisplayName: "Test Model", diff --git a/coderd/x/gitsync/worker_test.go b/coderd/x/gitsync/worker_test.go index 833ad5fae9..dabe0d1e8e 100644 --- a/coderd/x/gitsync/worker_test.go +++ b/coderd/x/gitsync/worker_test.go @@ -944,7 +944,7 @@ func TestWorker(t *testing.T) { user := dbgen.User(t, db, database.User{}) org := dbgen.Organization(t, db, database.Organization{}) - // 3. Set up FK chain: chat_providers -> chat_model_configs -> chats. + // 3. Set up FK chain: ai_providers -> chat_model_configs -> chats. _ = dbgen.ChatProvider(t, db, database.ChatProvider{}) modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index f94df3be8d..9587c7e6b3 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -122,14 +122,20 @@ func seedChatDependencies( UserID: user.ID, OrganizationID: org.ID, }) - _ = dbgen.ChatProvider(t, db, database.ChatProvider{ - BaseUrl: safetyNet.URL, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-" + uuid.NewString(), + BaseUrl: safetyNet.URL, + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, }) model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - IsDefault: true, + Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + IsDefault: true, }) return user, org, model } @@ -186,21 +192,24 @@ func setOpenAIProviderBaseURL( ) { t.Helper() - provider, err := db.GetChatProviderByProvider(ctx, "openai") - require.NoError(t, err) - - _, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: baseURL, - CentralApiKeyEnabled: true, - AllowUserApiKey: false, - AllowCentralApiKeyFallback: false, - ApiKeyKeyID: provider.ApiKeyKeyID, - Enabled: provider.Enabled, - }) + providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) require.NoError(t, err) + for _, provider := range providers { + if provider.Type != database.AiProviderTypeOpenai { + continue + } + _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ + ID: provider.ID, + DisplayName: provider.DisplayName, + Enabled: provider.Enabled, + BaseUrl: baseURL, + Settings: provider.Settings, + SettingsKeyID: provider.SettingsKeyID, + }) + require.NoError(t, err) + return + } + require.Fail(t, "openai provider not found") } func TestSubscribeRelayReconnectsOnDrop(t *testing.T) { diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index f573d080e6..28a04a5aa9 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -74,29 +74,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe } } - userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid) - if err != nil { - return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err) - } - for _, userProviderKey := range userProviderKeys { - if strings.TrimSpace(userProviderKey.APIKey) == "" { - continue - } - if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - continue - } - if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{ - UserID: userProviderKey.UserID, - ChatProviderID: userProviderKey.ChatProviderID, - APIKey: userProviderKey.APIKey, - ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required - }); err != nil { - return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err) - } - log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - } - userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid) if err != nil { return xerrors.Errorf("get user secrets for user %s: %w", uid, err) @@ -134,35 +111,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } - providers, err := cryptDB.GetChatProviders(ctx) - if err != nil { - return xerrors.Errorf("get chat providers: %w", err) - } - log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers))) - for idx, provider := range providers { - if strings.TrimSpace(provider.APIKey) == "" { - continue - } - if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - continue - } - if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: provider.BaseUrl, - ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required - Enabled: provider.Enabled, - CentralApiKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserApiKey: provider.AllowUserApiKey, - AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, - ID: provider.ID, - }); err != nil { - return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err) - } - log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - } - aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true}) if err != nil { return xerrors.Errorf("get ai providers: %w", err) @@ -313,26 +261,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph } } - userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid) - if err != nil { - return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err) - } - for _, userProviderKey := range userProviderKeys { - if !userProviderKey.ApiKeyKeyID.Valid { - log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1)) - continue - } - if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{ - UserID: userProviderKey.UserID, - ChatProviderID: userProviderKey.ChatProviderID, - APIKey: userProviderKey.APIKey, - ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id - }); err != nil { - return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err) - } - log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1)) - } - userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid) if err != nil { return xerrors.Errorf("get user secrets for user %s: %w", uid, err) @@ -370,31 +298,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } - providers, err := cryptDB.GetChatProviders(ctx) - if err != nil { - return xerrors.Errorf("get chat providers: %w", err) - } - log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers))) - for idx, provider := range providers { - if !provider.ApiKeyKeyID.Valid { - continue - } - if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: provider.BaseUrl, - ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id - Enabled: provider.Enabled, - CentralApiKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserApiKey: provider.AllowUserApiKey, - AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, - ID: provider.ID, - }); err != nil { - return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err) - } - log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - } - aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true}) if err != nil { return xerrors.Errorf("get ai providers: %w", err) @@ -475,16 +378,10 @@ DELETE FROM user_links DELETE FROM external_auth_links WHERE oauth_access_token_key_id IS NOT NULL OR oauth_refresh_token_key_id IS NOT NULL; -DELETE FROM user_chat_provider_keys - WHERE api_key_key_id IS NOT NULL; DELETE FROM user_ai_provider_keys WHERE api_key_key_id IS NOT NULL; DELETE FROM user_secrets WHERE value_key_id IS NOT NULL; -UPDATE chat_providers - SET api_key = '', - api_key_key_id = NULL - WHERE api_key_key_id IS NOT NULL; UPDATE ai_providers SET settings = NULL, settings_key_id = NULL @@ -502,9 +399,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error { store := database.New(sqlDB) _, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens) if err != nil { - return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err) + return xerrors.Errorf("delete encrypted tokens and AI provider keys: %w", err) } - log.Info(ctx, "deleted encrypted user tokens and chat provider API keys") + log.Info(ctx, "deleted encrypted user tokens and AI provider API keys") log.Info(ctx, "revoking all active keys") keys, err := store.GetDBCryptKeys(ctx) diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index b66ed6b3de..44cdb5554e 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -521,6 +521,19 @@ func (db *dbCrypt) GetAIProviderKeysByProviderID(ctx context.Context, providerID return keys, nil } +func (db *dbCrypt) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) { + keys, err := db.Store.GetAIProviderKeysByProviderIDs(ctx, providerIDs) + if err != nil { + return nil, err + } + for i := range keys { + if err := db.decryptAIProviderKey(&keys[i]); err != nil { + return nil, err + } + } + return keys, nil +} + func (db *dbCrypt) InsertAIProviderKey(ctx context.Context, params database.InsertAIProviderKeyParams) (database.AIProviderKey, error) { if strings.TrimSpace(params.APIKey) == "" { params.ApiKeyKeyID = sql.NullString{} @@ -576,92 +589,6 @@ func (db *dbCrypt) UpdateEncryptedAIProviderKey(ctx context.Context, params data return key, nil } -func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - provider, err := db.Store.GetChatProviderByID(ctx, id) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - -func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) { - provider, err := db.Store.GetChatProviderByProvider(ctx, providerName) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - -func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - providers, err := db.Store.GetChatProviders(ctx) - if err != nil { - return nil, err - } - - for i := range providers { - if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil { - return nil, err - } - } - - return providers, nil -} - -func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - providers, err := db.Store.GetEnabledChatProviders(ctx) - if err != nil { - return nil, err - } - - for i := range providers { - if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil { - return nil, err - } - } - - return providers, nil -} - -func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - - provider, err := db.Store.InsertChatProvider(ctx, params) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - -func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - - provider, err := db.Store.UpdateChatProvider(ctx, params) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error { return db.decryptField(&key.APIKey, key.ApiKeyKeyID) } @@ -754,57 +681,6 @@ func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params return key, nil } -func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error { - return db.decryptField(&key.APIKey, key.ApiKeyKeyID) -} - -func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - keys, err := db.Store.GetUserChatProviderKeys(ctx, userID) - if err != nil { - return nil, err - } - for i := range keys { - if err := db.decryptUserChatProviderKey(&keys[i]); err != nil { - return nil, err - } - } - return keys, nil -} - -func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.UserChatProviderKey{}, err - } - - key, err := db.Store.UpsertUserChatProviderKey(ctx, params) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := db.decryptUserChatProviderKey(&key); err != nil { - return database.UserChatProviderKey{}, err - } - return key, nil -} - -func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.UserChatProviderKey{}, err - } - - key, err := db.Store.UpdateUserChatProviderKey(ctx, params) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := db.decryptUserChatProviderKey(&key); err != nil { - return database.UserChatProviderKey{}, err - } - return key, nil -} - // decryptMCPServerConfig decrypts all encrypted fields on a // single MCPServerConfig in place. func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 5f70562863..e5a433399b 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1281,6 +1281,18 @@ func TestAIProviderKeys(t *testing.T) { requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey) }) + t.Run("GetAIProviderKeysByProviderIDs", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + keys, err := crypt.GetAIProviderKeysByProviderIDs(ctx, []uuid.UUID{provider.ID}) + require.NoError(t, err) + require.Len(t, keys, 1) + requireAIProviderKeyDecrypted(t, keys[0], ciphers, apiKey) + requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey) + }) + t.Run("DeleteAIProviderKey", func(t *testing.T) { t.Parallel() db, crypt, ciphers := setup(t) @@ -1558,113 +1570,6 @@ func TestMCPServerUserTokens(t *testing.T) { }) } -func TestUserChatProviderKeys(t *testing.T) { - t.Parallel() - ctx := context.Background() - - const ( - //nolint:gosec // test credentials - initialAPIKey = "sk-initial-api-key-value" - //nolint:gosec // test credentials - updatedAPIKey = "sk-updated-api-key-value" - ) - - insertProviderAndKey := func( - t *testing.T, - crypt *dbCrypt, - ciphers []Cipher, - ) (database.ChatProvider, database.UserChatProviderKey) { - t.Helper() - user := dbgen.User(t, crypt, database.User{}) - provider := dbgen.ChatProvider(t, crypt, database.ChatProvider{ - AllowUserApiKey: true, - }, func(params *database.InsertChatProviderParams) { - params.APIKey = "" - }) - - key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: user.ID, - ChatProviderID: provider.ID, - APIKey: initialAPIKey, - }) - require.NoError(t, err) - require.Equal(t, initialAPIKey, key.APIKey) - require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String) - return provider, key - } - - getUserChatProviderKey := func(t *testing.T, store interface { - GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error) - }, userID uuid.UUID, providerID uuid.UUID, - ) database.UserChatProviderKey { - t.Helper() - keys, err := store.GetUserChatProviderKeys(ctx, userID) - require.NoError(t, err) - require.Len(t, keys, 1) - require.Equal(t, providerID, keys[0].ChatProviderID) - return keys[0] - } - - t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) { - t.Parallel() - db, crypt, ciphers := setup(t) - provider, key := insertProviderAndKey(t, crypt, ciphers) - - got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID) - require.Equal(t, key.ID, got.ID) - require.Equal(t, initialAPIKey, got.APIKey) - require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) - - rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID) - require.NotEqual(t, initialAPIKey, rawKey.APIKey) - requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey) - }) - - t.Run("GetUserChatProviderKeys", func(t *testing.T) { - t.Parallel() - _, crypt, ciphers := setup(t) - _, key := insertProviderAndKey(t, crypt, ciphers) - - keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID) - require.NoError(t, err) - require.Len(t, keys, 1) - require.Equal(t, key.ID, keys[0].ID) - require.Equal(t, initialAPIKey, keys[0].APIKey) - require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String) - }) - - t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) { - t.Parallel() - db, crypt, ciphers := setup(t) - provider, key := insertProviderAndKey(t, crypt, ciphers) - - updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: key.UserID, - ChatProviderID: provider.ID, - APIKey: updatedAPIKey, - }) - require.NoError(t, err) - require.Equal(t, key.ID, updated.ID) - require.Equal(t, key.CreatedAt, updated.CreatedAt) - require.False(t, updated.UpdatedAt.Before(key.UpdatedAt)) - require.Equal(t, updatedAPIKey, updated.APIKey) - require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String) - - got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID) - require.Equal(t, updatedAPIKey, got.APIKey) - require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) - - keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID) - require.NoError(t, err) - require.Len(t, keys, 1) - require.Equal(t, updatedAPIKey, keys[0].APIKey) - - rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID) - require.NotEqual(t, updatedAPIKey, rawKey.APIKey) - requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey) - }) -} - func TestUserSecrets(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/scripts/check_emdash.sh b/scripts/check_emdash.sh index 4ed7da6175..ffd4e092ff 100755 --- a/scripts/check_emdash.sh +++ b/scripts/check_emdash.sh @@ -60,8 +60,13 @@ else echo "Base ref $base not found locally, fetching $ref..." git fetch origin "$ref" --depth=1 2>/dev/null || true if ! git rev-parse --verify "$base" >/dev/null 2>&1; then - echo "ERROR: could not fetch base ref $base." - exit 1 + if git rev-parse --verify origin/main >/dev/null 2>&1; then + echo "WARNING: could not fetch base ref $base, falling back to origin/main merge base." + base=$(git merge-base HEAD origin/main 2>/dev/null || echo "origin/main") + else + echo "ERROR: could not fetch base ref $base." + exit 1 + fi fi fi diff --git a/site/src/api/api.test.ts b/site/src/api/api.test.ts index 4c046c67d4..5d3abc999d 100644 --- a/site/src/api/api.test.ts +++ b/site/src/api/api.test.ts @@ -284,11 +284,6 @@ describe("api.ts", () => { providers: [], }, ], - [ - "/api/experimental/chats/providers", - () => API.experimental.getChatProviderConfigs(), - [], - ], [ "/api/experimental/chats/model-configs", () => API.experimental.getChatModelConfigs(), @@ -310,10 +305,6 @@ describe("api.ts", () => { "/api/experimental/chats/models", () => API.experimental.getChatModels(), ], - [ - "/api/experimental/chats/providers", - () => API.experimental.getChatProviderConfigs(), - ], [ "/api/experimental/chats/model-configs", () => API.experimental.getChatModelConfigs(), diff --git a/site/src/api/api.ts b/site/src/api/api.ts index a0b16a3f42..99d3e44e38 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -407,11 +407,8 @@ export type DeploymentConfig = Readonly<{ options: TypesGen.SerpentOption[]; }>; -const chatProviderConfigsPath = "/api/experimental/chats/providers"; const aiProviderConfigsPath = "/api/v2/ai/providers"; const chatModelConfigsPath = "/api/experimental/chats/model-configs"; -const userChatProviderConfigsPath = - "/api/experimental/chats/user-provider-configs"; const userSkillsPath = (user: string) => `/api/experimental/users/${encodeURIComponent(user)}/skills`; const userSkillPath = (user: string, name: string) => @@ -3731,42 +3728,6 @@ class ExperimentalApiMethods { ); }; - getChatProviderConfigs = async (): Promise => { - const response = await this.axios.get( - chatProviderConfigsPath, - ); - return response.data; - }; - - createChatProviderConfig = async ( - req: TypesGen.CreateChatProviderConfigRequest, - ): Promise => { - const response = await this.axios.post( - chatProviderConfigsPath, - req, - ); - return response.data; - }; - - updateChatProviderConfig = async ( - providerConfigId: string, - req: TypesGen.UpdateChatProviderConfigRequest, - ): Promise => { - const response = await this.axios.patch( - `${chatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`, - req, - ); - return response.data; - }; - - deleteChatProviderConfig = async ( - providerConfigId: string, - ): Promise => { - await this.axios.delete( - `${chatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`, - ); - }; - getChatModelConfigs = async (): Promise => { const response = await this.axios.get(chatModelConfigsPath); @@ -3800,34 +3761,6 @@ class ExperimentalApiMethods { ); }; - getUserChatProviderConfigs = async (): Promise< - TypesGen.UserChatProviderConfig[] - > => { - const response = await this.axios.get( - userChatProviderConfigsPath, - ); - return response.data; - }; - - upsertUserChatProviderKey = async ( - providerConfigId: string, - req: TypesGen.CreateUserChatProviderKeyRequest, - ): Promise => { - const response = await this.axios.put( - `${userChatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`, - req, - ); - return response.data; - }; - - deleteUserChatProviderKey = async ( - providerConfigId: string, - ): Promise => { - await this.axios.delete( - `${userChatProviderConfigsPath}/${encodeURIComponent(providerConfigId)}`, - ); - }; - getMCPServerConfigs = async (): Promise => { const response = await this.axios.get(mcpServerConfigsPath); diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx index 0d41cf8607..ec422bca5e 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx @@ -39,12 +39,36 @@ const createProviderConfig = ( updated_at: overrides.updated_at ?? now, }); +const createProviderKey = (providerId: string): TypesGen.AIProviderKey => ({ + id: `key-${providerId}`, + masked: "sk-...test", + created_at: now, +}); + +const toAIProvider = ( + providerConfig: TypesGen.ChatProviderConfig, +): TypesGen.AIProvider => ({ + id: providerConfig.id, + type: providerConfig.provider as TypesGen.AIProviderType, + name: providerConfig.provider, + display_name: providerConfig.display_name, + enabled: providerConfig.enabled, + base_url: providerConfig.base_url ?? "", + api_keys: providerConfig.has_api_key + ? [createProviderKey(providerConfig.id)] + : [], + settings: {}, + created_at: providerConfig.created_at ?? now, + updated_at: providerConfig.updated_at ?? now, +}); + const createModelConfig = ( overrides: Partial & Pick, ): TypesGen.ChatModelConfig => ({ id: overrides.id, provider: overrides.provider, + ai_provider_id: overrides.ai_provider_id, model: overrides.model, display_name: overrides.display_name ?? overrides.model, enabled: overrides.enabled ?? true, @@ -68,11 +92,9 @@ const setupChatSpies = (state: { modelConfigs: TypesGen.ChatModelConfig[]; modelCatalog: TypesGen.ChatModelsResponse; }) => { - spyOn(API.experimental, "getChatProviderConfigs").mockImplementation( - async () => { - return state.providerConfigs; - }, - ); + spyOn(API.experimental, "listAIProviders").mockImplementation(async () => { + return state.providerConfigs.map(toAIProvider); + }); spyOn(API.experimental, "getChatModelConfigs").mockImplementation( async () => { return state.modelConfigs; @@ -82,31 +104,26 @@ const setupChatSpies = (state: { return state.modelCatalog; }); - spyOn(API.experimental, "createChatProviderConfig").mockImplementation( + spyOn(API.experimental, "createAIProvider").mockImplementation( async (req) => { const created = createProviderConfig({ id: `provider-${Date.now()}`, - provider: req.provider ?? "", + provider: req.type ?? "openai", display_name: req.display_name ?? "", - has_api_key: (req.api_key ?? "").trim().length > 0, - central_api_key_enabled: req.central_api_key_enabled ?? true, - allow_user_api_key: req.allow_user_api_key ?? false, - allow_central_api_key_fallback: - req.allow_central_api_key_fallback ?? false, + has_api_key: + req.api_keys?.some((apiKey) => apiKey.trim().length > 0) ?? false, base_url: req.base_url ?? "", + enabled: req.enabled ?? true, source: "database", }); state.providerConfigs = [ - ...state.providerConfigs.filter( - (p) => !(p.id === nilProviderConfigID && p.provider === req.provider), - ), + ...state.providerConfigs.filter((p) => p.id !== created.id), created, ]; - return created; + return toAIProvider(created); }, ); - - spyOn(API.experimental, "updateChatProviderConfig").mockImplementation( + spyOn(API.experimental, "updateAIProvider").mockImplementation( async (providerConfigId, req) => { const idx = state.providerConfigs.findIndex( (p) => p.id === providerConfigId, @@ -122,29 +139,30 @@ const setupChatSpies = (state: { ? req.display_name : current.display_name, has_api_key: - typeof req.api_key === "string" - ? req.api_key.trim().length > 0 - : current.has_api_key, - central_api_key_enabled: - typeof req.central_api_key_enabled === "boolean" - ? req.central_api_key_enabled - : current.central_api_key_enabled, - allow_user_api_key: - typeof req.allow_user_api_key === "boolean" - ? req.allow_user_api_key - : current.allow_user_api_key, - allow_central_api_key_fallback: - typeof req.allow_central_api_key_fallback === "boolean" - ? req.allow_central_api_key_fallback - : current.allow_central_api_key_fallback, + req.api_keys === undefined + ? current.has_api_key + : req.api_keys.some((apiKey) => + apiKey.api_key !== undefined + ? apiKey.api_key.trim().length > 0 + : apiKey.id !== undefined, + ), base_url: typeof req.base_url === "string" ? req.base_url : current.base_url, + enabled: + typeof req.enabled === "boolean" ? req.enabled : current.enabled, updated_at: now, }; state.providerConfigs = state.providerConfigs.map((p, i) => i === idx ? updated : p, ); - return updated; + return toAIProvider(updated); + }, + ); + spyOn(API.experimental, "deleteAIProvider").mockImplementation( + async (providerConfigId) => { + state.providerConfigs = state.providerConfigs.filter( + (p) => p.id !== providerConfigId, + ); }, ); @@ -154,6 +172,7 @@ const setupChatSpies = (state: { id: `model-${state.modelConfigs.length + 1}`, provider: req.provider ?? "", model: req.model, + ai_provider_id: req.ai_provider_id, display_name: req.display_name || req.model, enabled: req.enabled ?? true, context_limit: @@ -181,10 +200,6 @@ const setupChatSpies = (state: { }, ); - // Unused but mock to avoid errors. - spyOn(API.experimental, "deleteChatProviderConfig").mockResolvedValue( - undefined, - ); spyOn(API.experimental, "updateChatModelConfig").mockImplementation( async (modelConfigId, req) => { const idx = state.modelConfigs.findIndex((m) => m.id === modelConfigId);