diff --git a/AGENTS.md b/AGENTS.md index b585a05210..0d0f9635c0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -110,6 +110,9 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID) - For experimental or unstable API paths, skip public doc generation with `// @x-apidocgen {"skip": true}` after the `@Router` annotation. This keeps them out of the published API reference until they stabilize. +- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger + annotations entirely. Do not add `@Summary`, `@Router`, or other + swagger comments to handlers in that file. ### Database Query Naming diff --git a/coderd/coderd.go b/coderd/coderd.go index e61e5fe8c5..1c4f5d6d5c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -782,7 +782,7 @@ func New(options *Options) *API { ReplicaID: api.ID, SubscribeFn: options.ChatSubscribeFn, MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above. - ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues), + ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues), AgentConn: api.agentProvider.AgentConn, AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, CreateWorkspace: api.chatCreateWorkspace, @@ -1221,6 +1221,13 @@ func New(options *Options) *API { r.Delete("/", api.deleteChatUsageLimitGroupOverride) }) }) + r.Route("/user-provider-configs", func(r chi.Router) { + r.Get("/", api.listUserChatProviderConfigs) + r.Route("/{providerConfig}", func(r chi.Router) { + r.Put("/", api.upsertUserChatProviderKey) + r.Delete("/", api.deleteUserChatProviderKey) + }) + }) r.Route("/{chat}", func(r chi.Router) { r.Use(httpmw.ExtractChatParam(options.Database)) r.Get("/", api.getChat) diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index 51f8326779..5223dc17b7 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -10,6 +10,7 @@ const ( CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers + CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config @@ -32,4 +33,5 @@ const ( CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events + CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys ) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 61169cec0e..2ec987c336 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2137,6 +2137,17 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat return q.db.DeleteUserChatCompactionThreshold(ctx, arg) } +func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { + return err + } + return q.db.DeleteUserChatProviderKey(ctx, arg) +} + func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error { // First get the secret to check ownership secret, err := q.GetUserSecret(ctx, id) @@ -4024,6 +4035,17 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) return q.db.GetUserChatCustomPrompt(ctx, userID) } +func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { + u, err := q.db.GetUserByID(ctx, userID) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil { + return nil, err + } + return q.db.GetUserChatProviderKeys(ctx, userID) +} + func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil { return 0, err @@ -6454,6 +6476,17 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U return q.db.UpdateUserChatCustomPrompt(ctx, arg) } +func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return database.UserChatProviderKey{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { + return database.UserChatProviderKey{}, err + } + return q.db.UpdateUserChatProviderKey(ctx, arg) +} + func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id) } @@ -7181,6 +7214,17 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error { return q.db.UpsertTemplateUsageStats(ctx) } +func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return database.UserChatProviderKey{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { + return database.UserChatProviderKey{}, err + } + return q.db.UpsertUserChatProviderKey(ctx, arg) +} + func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 64fb95dfbd..8066580b22 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2407,6 +2407,36 @@ func (s *MethodTestSuite) TestUser() { dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt") })) + s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes() + check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key}) + })) + s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns() + })) + s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"} + key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) + })) + s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"} + key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) + })) s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { u := testutil.Fake(s.T(), faker, database.User{}) uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"} diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 377d7351a4..6e13cee446 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -696,6 +696,14 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context return r0 } +func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { + start := time.Now() + r0 := m.s.DeleteUserChatProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc() + return r0 +} + func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error { start := time.Now() r0 := m.s.DeleteUserSecret(ctx, id) @@ -2528,6 +2536,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u return r0, r1 } +func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { + start := time.Now() + r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID) + m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { start := time.Now() r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg) @@ -4560,6 +4576,14 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d return r0, r1 } +func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + start := time.Now() + r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error { start := time.Now() r0 := m.s.UpdateUserDeletedByID(ctx, id) @@ -5152,6 +5176,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error { return r0 } +func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + start := time.Now() + r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { start := time.Now() r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 3a79163601..4fb757678e 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1171,6 +1171,20 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg) } +// DeleteUserChatProviderKey mocks base method. +func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey. +func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg) +} + // DeleteUserSecret mocks base method. func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error { m.ctrl.T.Helper() @@ -4729,6 +4743,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID) } +// GetUserChatProviderKeys mocks base method. +func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID) + ret0, _ := ret[0].([]database.UserChatProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys. +func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID) +} + // GetUserChatSpendInPeriod mocks base method. func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { m.ctrl.T.Helper() @@ -8605,6 +8634,21 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg) } +// UpdateUserChatProviderKey mocks base method. +func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg) + ret0, _ := ret[0].(database.UserChatProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey. +func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg) +} + // UpdateUserDeletedByID mocks base method. func (m *MockStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error { m.ctrl.T.Helper() @@ -9671,6 +9715,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx) } +// UpsertUserChatProviderKey mocks base method. +func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg) + ret0, _ := ret[0].(database.UserChatProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey. +func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg) +} + // UpsertWebpushVAPIDKeys mocks base method. func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 75605f9bcf..68a17e99a5 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1341,7 +1341,11 @@ CREATE TABLE chat_providers ( created_at timestamp with time zone DEFAULT now() NOT NULL, updated_at timestamp with time zone DEFAULT now() NOT NULL, base_url text DEFAULT ''::text NOT NULL, - CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))) + central_api_key_enabled boolean DEFAULT true NOT NULL, + allow_user_api_key boolean DEFAULT false NOT NULL, + allow_central_api_key_fallback boolean DEFAULT false NOT NULL, + CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))), + CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key)))) ); COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted'; @@ -2752,6 +2756,17 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.'; +CREATE TABLE user_chat_provider_keys ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + user_id uuid NOT NULL, + chat_provider_id uuid NOT NULL, + api_key text NOT NULL, + api_key_key_id text, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text)) +); + CREATE TABLE user_configs ( user_id uuid NOT NULL, key character varying(256) NOT NULL, @@ -3548,6 +3563,12 @@ ALTER TABLE ONLY usage_events_daily ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); +ALTER TABLE ONLY user_chat_provider_keys + ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY user_chat_provider_keys + ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); + ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); @@ -4258,6 +4279,15 @@ ALTER TABLE ONLY templates ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY user_chat_provider_keys + ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY user_chat_provider_keys + ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; + +ALTER TABLE ONLY user_chat_provider_keys + ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 4f7ec37f0b..c326777574 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -92,6 +92,9 @@ const ( ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE; ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT; ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; + ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/migrations/000459_provider_key_policy.down.sql b/coderd/database/migrations/000459_provider_key_policy.down.sql new file mode 100644 index 0000000000..b7a5bc2a55 --- /dev/null +++ b/coderd/database/migrations/000459_provider_key_policy.down.sql @@ -0,0 +1,8 @@ +DROP TABLE IF EXISTS user_chat_provider_keys; + +ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy; + +ALTER TABLE chat_providers + DROP COLUMN IF EXISTS central_api_key_enabled, + DROP COLUMN IF EXISTS allow_user_api_key, + DROP COLUMN IF EXISTS allow_central_api_key_fallback; diff --git a/coderd/database/migrations/000459_provider_key_policy.up.sql b/coderd/database/migrations/000459_provider_key_policy.up.sql new file mode 100644 index 0000000000..f4a7655c1b --- /dev/null +++ b/coderd/database/migrations/000459_provider_key_policy.up.sql @@ -0,0 +1,24 @@ +ALTER TABLE chat_providers + ADD COLUMN central_api_key_enabled BOOLEAN NOT NULL DEFAULT TRUE, + ADD COLUMN allow_user_api_key BOOLEAN NOT NULL DEFAULT FALSE, + ADD COLUMN allow_central_api_key_fallback BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE chat_providers + ADD CONSTRAINT valid_credential_policy CHECK ( + (central_api_key_enabled OR allow_user_api_key) AND + ( + NOT allow_central_api_key_fallback OR + (central_api_key_enabled AND allow_user_api_key) + ) + ); + +CREATE TABLE user_chat_provider_keys ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + chat_provider_id UUID NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE, + api_key TEXT NOT NULL CHECK (api_key != ''), + api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (user_id, chat_provider_id) +); diff --git a/coderd/database/migrations/testdata/fixtures/000459_provider_key_policy.up.sql b/coderd/database/migrations/testdata/fixtures/000459_provider_key_policy.up.sql new file mode 100644 index 0000000000..68458a3066 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000459_provider_key_policy.up.sql @@ -0,0 +1,16 @@ +INSERT INTO user_chat_provider_keys ( + user_id, + chat_provider_id, + api_key, + created_at, + updated_at +) +SELECT + id, + '0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7', + 'fixture-test-key', + '2025-01-01 00:00:00+00', + '2025-01-01 00:00:00+00' +FROM users +ORDER BY created_at, id +LIMIT 1; diff --git a/coderd/database/models.go b/coderd/database/models.go index b66f6e5e17..f34c5b6b71 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4264,12 +4264,15 @@ type ChatProvider struct { DisplayName string `db:"display_name" json:"display_name"` APIKey string `db:"api_key" json:"api_key"` // The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - Enabled bool `db:"enabled" json:"enabled"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - BaseUrl string `db:"base_url" json:"base_url"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` + Enabled bool `db:"enabled" json:"enabled"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + BaseUrl string `db:"base_url" json:"base_url"` + CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` + AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` + AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` } type ChatQueuedMessage struct { @@ -5222,6 +5225,16 @@ type User struct { ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"` } +type UserChatProviderKey struct { + ID uuid.UUID `db:"id" json:"id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` + APIKey string `db:"api_key" json:"api_key"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + type UserConfig struct { UserID uuid.UUID `db:"user_id" json:"user_id"` Key string `db:"key" json:"key"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 50bc6dc9e5..c34b54d127 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -150,6 +150,7 @@ type sqlcQuerier interface { DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error) DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error) DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error + DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error DeleteUserSecret(ctx context.Context, id uuid.UUID) error DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error @@ -577,6 +578,7 @@ type sqlcQuerier interface { GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) + GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) // Returns the minimum (most restrictive) group limit for a user. @@ -927,6 +929,7 @@ type sqlcQuerier interface { UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error) UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error) + UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error UpdateUserHashedOneTimePasscode(ctx context.Context, arg UpdateUserHashedOneTimePasscodeParams) error @@ -1015,6 +1018,7 @@ type sqlcQuerier interface { // used to store the data, and the minutes are summed for each user and template // combination. The result is stored in the template_usage_stats table. UpsertTemplateUsageStats(ctx context.Context) error + UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error) UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 126d24010c..db31e88c69 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -1261,10 +1261,11 @@ func TestGetAuthorizedChats(t *testing.T) { // Create FK dependencies: a chat provider and model config. ctx := testutil.Context(t, testutil.WaitMedium) _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -9456,10 +9457,11 @@ func TestInsertChatMessages(t *testing.T) { provider := "openai" _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: provider, - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, + Provider: provider, + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -9621,10 +9623,11 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) { // A chat_providers row is required as a FK for model configs. _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -9992,10 +9995,11 @@ func TestGetPRInsights(t *testing.T) { user := dbgen.User(t, store, database.User{}) _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "anthropic", - DisplayName: "Anthropic", - APIKey: "test-key", - Enabled: true, + Provider: "anthropic", + DisplayName: "Anthropic", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -10516,10 +10520,11 @@ func TestChatPinOrderQueries(t *testing.T) { // timed test context doesn't tick during DB init. bg := context.Background() _, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -10696,10 +10701,11 @@ func TestChatLabels(t *testing.T) { owner := dbgen.User(t, db, database.User{}) _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -10907,10 +10913,11 @@ func TestChatHasUnread(t *testing.T) { user := dbgen.User(t, store, database.User{}) _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 15d11ac4b0..faef8406f3 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3799,7 +3799,7 @@ func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) e const getChatProviderByID = `-- name: GetChatProviderByID :one SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url + id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback FROM chat_providers WHERE @@ -3820,13 +3820,16 @@ func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (Cha &i.CreatedAt, &i.UpdatedAt, &i.BaseUrl, + &i.CentralApiKeyEnabled, + &i.AllowUserApiKey, + &i.AllowCentralApiKeyFallback, ) return i, err } const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url + id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback FROM chat_providers WHERE @@ -3847,13 +3850,16 @@ func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider str &i.CreatedAt, &i.UpdatedAt, &i.BaseUrl, + &i.CentralApiKeyEnabled, + &i.AllowUserApiKey, + &i.AllowCentralApiKeyFallback, ) return i, err } const getChatProviders = `-- name: GetChatProviders :many SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url + id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback FROM chat_providers ORDER BY @@ -3880,6 +3886,9 @@ func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, erro &i.CreatedAt, &i.UpdatedAt, &i.BaseUrl, + &i.CentralApiKeyEnabled, + &i.AllowUserApiKey, + &i.AllowCentralApiKeyFallback, ); err != nil { return nil, err } @@ -3896,7 +3905,7 @@ func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, erro const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url + id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback FROM chat_providers WHERE @@ -3925,6 +3934,9 @@ func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvide &i.CreatedAt, &i.UpdatedAt, &i.BaseUrl, + &i.CentralApiKeyEnabled, + &i.AllowUserApiKey, + &i.AllowCentralApiKeyFallback, ); err != nil { return nil, err } @@ -3947,7 +3959,10 @@ INSERT INTO chat_providers ( base_url, api_key_key_id, created_by, - enabled + enabled, + central_api_key_enabled, + allow_user_api_key, + allow_central_api_key_fallback ) VALUES ( $1::text, $2::text, @@ -3955,20 +3970,26 @@ INSERT INTO chat_providers ( $4::text, $5::text, $6::uuid, - $7::boolean + $7::boolean, + $8::boolean, + $9::boolean, + $10::boolean ) RETURNING - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url + id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback ` type InsertChatProviderParams struct { - Provider string `db:"provider" json:"provider"` - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - BaseUrl string `db:"base_url" json:"base_url"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - Enabled bool `db:"enabled" json:"enabled"` + Provider string `db:"provider" json:"provider"` + DisplayName string `db:"display_name" json:"display_name"` + APIKey string `db:"api_key" json:"api_key"` + BaseUrl string `db:"base_url" json:"base_url"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` + Enabled bool `db:"enabled" json:"enabled"` + CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` + AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` + AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` } func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) { @@ -3980,6 +4001,9 @@ func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProvi arg.ApiKeyKeyID, arg.CreatedBy, arg.Enabled, + arg.CentralApiKeyEnabled, + arg.AllowUserApiKey, + arg.AllowCentralApiKeyFallback, ) var i ChatProvider err := row.Scan( @@ -3993,6 +4017,9 @@ func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProvi &i.CreatedAt, &i.UpdatedAt, &i.BaseUrl, + &i.CentralApiKeyEnabled, + &i.AllowUserApiKey, + &i.AllowCentralApiKeyFallback, ) return i, err } @@ -4006,20 +4033,26 @@ SET base_url = $3::text, api_key_key_id = $4::text, enabled = $5::boolean, + central_api_key_enabled = $6::boolean, + allow_user_api_key = $7::boolean, + allow_central_api_key_fallback = $8::boolean, updated_at = NOW() WHERE - id = $6::uuid + id = $9::uuid RETURNING - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url + id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback ` type UpdateChatProviderParams struct { - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - BaseUrl string `db:"base_url" json:"base_url"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - Enabled bool `db:"enabled" json:"enabled"` - ID uuid.UUID `db:"id" json:"id"` + DisplayName string `db:"display_name" json:"display_name"` + APIKey string `db:"api_key" json:"api_key"` + BaseUrl string `db:"base_url" json:"base_url"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + Enabled bool `db:"enabled" json:"enabled"` + CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` + AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` + AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` + ID uuid.UUID `db:"id" json:"id"` } func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) { @@ -4029,6 +4062,9 @@ func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProvi arg.BaseUrl, arg.ApiKeyKeyID, arg.Enabled, + arg.CentralApiKeyEnabled, + arg.AllowUserApiKey, + arg.AllowCentralApiKeyFallback, arg.ID, ) var i ChatProvider @@ -4043,6 +4079,9 @@ func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProvi &i.CreatedAt, &i.UpdatedAt, &i.BaseUrl, + &i.CentralApiKeyEnabled, + &i.AllowUserApiKey, + &i.AllowCentralApiKeyFallback, ) return i, err } @@ -22617,6 +22656,126 @@ func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretP return i, err } +const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec +DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2 +` + +type DeleteUserChatProviderKeyParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` +} + +func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error { + _, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID) + return err +} + +const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many +SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC +` + +func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) { + rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []UserChatProviderKey + for rows.Next() { + var i UserChatProviderKey + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.ChatProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one +UPDATE user_chat_provider_keys +SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW() +WHERE user_id = $3 AND chat_provider_id = $4 +RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at +` + +type UpdateUserChatProviderKeyParams struct { + APIKey string `db:"api_key" json:"api_key"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` +} + +func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) { + row := q.db.QueryRowContext(ctx, updateUserChatProviderKey, + arg.APIKey, + arg.ApiKeyKeyID, + arg.UserID, + arg.ChatProviderID, + ) + var i UserChatProviderKey + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChatProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one +INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id) +VALUES ($1, $2, $3, $4::text) +ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET + api_key = $3, + api_key_key_id = $4::text, + updated_at = NOW() +RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at +` + +type UpsertUserChatProviderKeyParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` + APIKey string `db:"api_key" json:"api_key"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` +} + +func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) { + row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey, + arg.UserID, + arg.ChatProviderID, + arg.APIKey, + arg.ApiKeyKeyID, + ) + var i UserChatProviderKey + err := row.Scan( + &i.ID, + &i.UserID, + &i.ChatProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const allUserIDs = `-- name: AllUserIDs :many SELECT DISTINCT id FROM USERS WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END diff --git a/coderd/database/queries/chatproviders.sql b/coderd/database/queries/chatproviders.sql index 228fbf3b28..02edac049d 100644 --- a/coderd/database/queries/chatproviders.sql +++ b/coderd/database/queries/chatproviders.sql @@ -40,7 +40,10 @@ INSERT INTO chat_providers ( base_url, api_key_key_id, created_by, - enabled + enabled, + central_api_key_enabled, + allow_user_api_key, + allow_central_api_key_fallback ) VALUES ( @provider::text, @display_name::text, @@ -48,7 +51,10 @@ INSERT INTO chat_providers ( @base_url::text, sqlc.narg('api_key_key_id')::text, sqlc.narg('created_by')::uuid, - @enabled::boolean + @enabled::boolean, + @central_api_key_enabled::boolean, + @allow_user_api_key::boolean, + @allow_central_api_key_fallback::boolean ) RETURNING *; @@ -62,6 +68,9 @@ SET base_url = @base_url::text, api_key_key_id = sqlc.narg('api_key_key_id')::text, enabled = @enabled::boolean, + central_api_key_enabled = @central_api_key_enabled::boolean, + allow_user_api_key = @allow_user_api_key::boolean, + allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean, updated_at = NOW() WHERE id = @id::uuid diff --git a/coderd/database/queries/userchatproviderkeys.sql b/coderd/database/queries/userchatproviderkeys.sql new file mode 100644 index 0000000000..38c177156e --- /dev/null +++ b/coderd/database/queries/userchatproviderkeys.sql @@ -0,0 +1,20 @@ +-- name: GetUserChatProviderKeys :many +SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC; + +-- name: UpsertUserChatProviderKey :one +INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id) +VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text) +ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET + api_key = @api_key, + api_key_key_id = sqlc.narg('api_key_key_id')::text, + updated_at = NOW() +RETURNING *; + +-- name: UpdateUserChatProviderKey :one +UPDATE user_chat_provider_keys +SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW() +WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id +RETURNING *; + +-- name: DeleteUserChatProviderKey :exec +DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 8a123be0cb..a329058947 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -90,6 +90,8 @@ const ( UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id); UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type); UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); + UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); + UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id); UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type); diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 595648beea..6d352fdc74 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -520,6 +520,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { // EXPERIMENTAL: this endpoint is experimental and is subject to change. func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + apiKey := httpmw.APIKey(r) //nolint:gocritic // System context required to read enabled chat models. systemCtx := dbauthz.AsSystemRestricted(ctx) @@ -555,14 +556,24 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) { configuredProviders := make( []chatprovider.ConfiguredProvider, 0, len(enabledProviders), ) + enabledProviderNames := make(map[string]struct{}, len(enabledProviders)) for _, provider := range enabledProviders { configuredProviders = append( configuredProviders, chatprovider.ConfiguredProvider{ - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, + ProviderID: provider.ID, + Provider: provider.Provider, + APIKey: provider.APIKey, + BaseURL: provider.BaseUrl, + CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, + AllowUserAPIKey: provider.AllowUserApiKey, + AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, }, ) + normalizedProvider := chatprovider.NormalizeProvider(provider.Provider) + if normalizedProvider == "" { + continue + } + enabledProviderNames[normalizedProvider] = struct{}{} } configuredModels := make( []chatprovider.ConfiguredModel, 0, len(enabledModels), @@ -575,18 +586,38 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) { }) } - keys := chatprovider.MergeProviderAPIKeys( - chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), + userKeyRows, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to load user chat provider keys.", + Detail: err.Error(), + }) + return + } + userKeys := make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) + for _, userKey := range userKeyRows { + userKeys = append(userKeys, chatprovider.UserProviderKey{ + ChatProviderID: userKey.ChatProviderID, + APIKey: userKey.APIKey, + }) + } + + _, providerAvailability := chatprovider.ResolveUserProviderKeys( + ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), configuredProviders, + userKeys, ) - catalog := chatprovider.NewModelCatalog(keys) + catalog := chatprovider.NewModelCatalog() var response codersdk.ChatModelsResponse if configured, ok := catalog.ListConfiguredModels( - configuredProviders, configuredModels, + configuredProviders, configuredModels, providerAvailability, enabledProviderNames, ); ok { response = configured } else { - response = catalog.ListConfiguredProviderAvailability(configuredProviders) + response = catalog.ListConfiguredProviderAvailability( + providerAvailability, + enabledProviderNames, + ) } httpapi.Write(ctx, rw, http.StatusOK, response) @@ -3926,9 +3957,13 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { []chatprovider.ConfiguredProvider, 0, len(enabledProviders), ) for _, provider := range enabledProviders { + normalizedProvider := normalizeChatProvider(provider.Provider) + if normalizedProvider == "" { + continue + } enabledConfiguredProviders = append( enabledConfiguredProviders, chatprovider.ConfiguredProvider{ - Provider: provider.Provider, + Provider: normalizedProvider, APIKey: provider.APIKey, BaseURL: provider.BaseUrl, }, @@ -3936,7 +3971,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { } effectiveKeys := chatprovider.MergeProviderAPIKeys( - chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), + ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), enabledConfiguredProviders, ) effectiveKeys = chatprovider.MergeProviderAPIKeys( @@ -3952,7 +3987,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { resp, convertChatProviderConfig( configured, - effectiveKeys.APIKey(provider) != "", + api.hasEffectiveProviderAPIKey(ctx, configured), codersdk.ChatProviderConfigSourceDatabase, ), ) @@ -3968,13 +4003,16 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { } resp = append(resp, codersdk.ChatProviderConfig{ - ID: uuid.Nil, - Provider: provider, - DisplayName: chatprovider.ProviderDisplayName(provider), - Enabled: enabled, - HasAPIKey: hasAPIKey, - BaseURL: effectiveKeys.BaseURL(provider), - Source: source, + ID: uuid.Nil, + Provider: provider, + DisplayName: chatprovider.ProviderDisplayName(provider), + Enabled: enabled, + HasAPIKey: hasAPIKey, + CentralAPIKeyEnabled: true, + AllowUserAPIKey: false, + AllowCentralAPIKeyFallback: false, + BaseURL: effectiveKeys.BaseURL(provider), + Source: source, }) } @@ -3984,6 +4022,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() apiKey := httpmw.APIKey(r) + var inserted database.ChatProvider if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { httpapi.Forbidden(rw) return @@ -4003,6 +4042,14 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) { return } + if err := validateChatProviderAPIKeySize(strings.TrimSpace(req.APIKey)); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key too large.", + Detail: err.Error(), + }) + return + } + enabled := true if req.Enabled != nil { enabled = *req.Enabled @@ -4016,14 +4063,57 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) { return } - inserted, err := api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: provider, - DisplayName: strings.TrimSpace(req.DisplayName), - APIKey: strings.TrimSpace(req.APIKey), - BaseUrl: baseURL, - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil}, - Enabled: enabled, + centralAPIKeyEnabled := true + if req.CentralAPIKeyEnabled != nil { + centralAPIKeyEnabled = *req.CentralAPIKeyEnabled + } + allowUserAPIKey := false + if req.AllowUserAPIKey != nil { + allowUserAPIKey = *req.AllowUserAPIKey + } + allowCentralAPIKeyFallback := false + if req.AllowCentralAPIKeyFallback != nil { + allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback + } + + if err := validateChatProviderCredentialPolicy( + centralAPIKeyEnabled, + allowUserAPIKey, + allowCentralAPIKeyFallback, + ); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid credential policy.", + Detail: err.Error(), + }) + return + } + + if err := validateChatProviderCentralAPIKey( + centralAPIKeyEnabled, + api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{ + Provider: provider, + APIKey: strings.TrimSpace(req.APIKey), + BaseUrl: baseURL, + CentralApiKeyEnabled: centralAPIKeyEnabled, + }, uuid.Nil), + ); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: err.Error(), + }) + return + } + + inserted, err = api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: provider, + DisplayName: strings.TrimSpace(req.DisplayName), + APIKey: strings.TrimSpace(req.APIKey), + BaseUrl: baseURL, + ApiKeyKeyID: sql.NullString{}, + CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil}, + Enabled: enabled, + CentralApiKeyEnabled: centralAPIKeyEnabled, + AllowUserApiKey: allowUserAPIKey, + AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, }) if err != nil { switch { @@ -4064,6 +4154,10 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) { func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + var ( + existing database.ChatProvider + updated database.ChatProvider + ) if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { httpapi.Forbidden(rw) return @@ -4105,7 +4199,17 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { apiKey := existing.APIKey apiKeyKeyID := existing.ApiKeyKeyID if req.APIKey != nil { - apiKey = strings.TrimSpace(*req.APIKey) + trimmedAPIKey := strings.TrimSpace(*req.APIKey) + if trimmedAPIKey != "" { + if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key too large.", + Detail: err.Error(), + }) + return + } + } + apiKey = trimmedAPIKey apiKeyKeyID = sql.NullString{} } baseURL := existing.BaseUrl @@ -4120,13 +4224,57 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { } } - updated, err := api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: displayName, - APIKey: apiKey, - BaseUrl: baseURL, - ApiKeyKeyID: apiKeyKeyID, - Enabled: enabled, - ID: existing.ID, + centralAPIKeyEnabled := existing.CentralApiKeyEnabled + if req.CentralAPIKeyEnabled != nil { + centralAPIKeyEnabled = *req.CentralAPIKeyEnabled + } + allowUserAPIKey := existing.AllowUserApiKey + if req.AllowUserAPIKey != nil { + allowUserAPIKey = *req.AllowUserAPIKey + } + allowCentralAPIKeyFallback := existing.AllowCentralApiKeyFallback + if req.AllowCentralAPIKeyFallback != nil { + allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback + } + + if err := validateChatProviderCredentialPolicy( + centralAPIKeyEnabled, + allowUserAPIKey, + allowCentralAPIKeyFallback, + ); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid credential policy.", + Detail: err.Error(), + }) + return + } + + if err := validateChatProviderCentralAPIKey( + centralAPIKeyEnabled, + api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{ + ID: existing.ID, + Provider: existing.Provider, + APIKey: apiKey, + BaseUrl: baseURL, + CentralApiKeyEnabled: centralAPIKeyEnabled, + }, existing.ID), + ); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: err.Error(), + }) + return + } + + updated, err = api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ + DisplayName: displayName, + APIKey: apiKey, + BaseUrl: baseURL, + ApiKeyKeyID: apiKeyKeyID, + Enabled: enabled, + CentralApiKeyEnabled: centralAPIKeyEnabled, + AllowUserApiKey: allowUserAPIKey, + AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, + ID: existing.ID, }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -4197,6 +4345,169 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusNoContent) } +func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + apiKey = httpmw.APIKey(r) + ) + + //nolint:gocritic // Non-admin users need to read provider configs to manage their own chat credentials. + chatdCtx := dbauthz.AsChatd(ctx) + providers, err := api.Database.GetChatProviders(chatdCtx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to list chat providers.", + Detail: err.Error(), + }) + return + } + + userKeys, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to list user chat provider keys.", + Detail: err.Error(), + }) + return + } + + hasUserAPIKeyByProviderID := make(map[uuid.UUID]bool, len(userKeys)) + for _, userKey := range userKeys { + hasUserAPIKeyByProviderID[userKey.ChatProviderID] = true + } + + resp := make([]codersdk.UserChatProviderConfig, 0, len(providers)) + for _, provider := range providers { + if !provider.Enabled || !provider.AllowUserApiKey { + continue + } + hasUserAPIKey := hasUserAPIKeyByProviderID[provider.ID] + hasCentralAPIKeyFallback := provider.Enabled && + provider.AllowCentralApiKeyFallback && + api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil) + resp = append( + resp, + convertUserChatProviderConfig( + provider, + hasUserAPIKey, + hasCentralAPIKeyFallback, + ), + ) + } + + httpapi.Write(ctx, rw, http.StatusOK, resp) +} + +func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + apiKey = httpmw.APIKey(r) + ) + + providerID, ok := parseChatProviderID(rw, r) + if !ok { + return + } + + //nolint:gocritic // Non-admin users need to validate provider availability before storing their own key. + provider, err := api.Database.GetChatProviderByID(dbauthz.AsChatd(ctx), providerID) + if err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get chat provider.", + Detail: err.Error(), + }) + return + } + if !provider.Enabled { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Provider is disabled.", + }) + return + } + if !provider.AllowUserApiKey { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Provider does not allow user API keys.", + }) + return + } + + var req codersdk.CreateUserChatProviderKeyRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + trimmedAPIKey := strings.TrimSpace(req.APIKey) + if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key too large.", + Detail: err.Error(), + }) + return + } + if trimmedAPIKey == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key is required.", + }) + return + } + + if _, err := api.Database.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ + UserID: apiKey.UserID, + ChatProviderID: providerID, + APIKey: trimmedAPIKey, + ApiKeyKeyID: sql.NullString{}, + }); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to save user chat provider key.", + Detail: err.Error(), + }) + return + } + + hasCentralAPIKeyFallback := provider.Enabled && + provider.AllowCentralApiKeyFallback && + api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil) + httpapi.Write( + ctx, + rw, + http.StatusOK, + convertUserChatProviderConfig( + provider, + true, + hasCentralAPIKeyFallback, + ), + ) +} + +func (api *API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + apiKey = httpmw.APIKey(r) + ) + + providerID, ok := parseChatProviderID(rw, r) + if !ok { + return + } + + if err := api.Database.DeleteUserChatProviderKey(ctx, database.DeleteUserChatProviderKeyParams{ + UserID: apiKey.UserID, + ChatProviderID: providerID, + }); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to delete user chat provider key.", + Detail: err.Error(), + }) + return + } + + rw.WriteHeader(http.StatusNoContent) +} + func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -4737,15 +5048,37 @@ func convertChatProviderConfig( } return codersdk.ChatProviderConfig{ - ID: provider.ID, - Provider: provider.Provider, - DisplayName: displayName, - Enabled: provider.Enabled, - HasAPIKey: hasAPIKey, - BaseURL: strings.TrimSpace(provider.BaseUrl), - Source: source, - CreatedAt: provider.CreatedAt, - UpdatedAt: provider.UpdatedAt, + ID: provider.ID, + Provider: provider.Provider, + DisplayName: displayName, + Enabled: provider.Enabled, + HasAPIKey: hasAPIKey, + CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, + AllowUserAPIKey: provider.AllowUserApiKey, + AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, + BaseURL: strings.TrimSpace(provider.BaseUrl), + Source: source, + CreatedAt: provider.CreatedAt, + UpdatedAt: provider.UpdatedAt, + } +} + +func convertUserChatProviderConfig( + provider database.ChatProvider, + hasUserAPIKey bool, + hasCentralAPIKeyFallback bool, +) codersdk.UserChatProviderConfig { + displayName := strings.TrimSpace(provider.DisplayName) + if displayName == "" { + displayName = chatprovider.ProviderDisplayName(provider.Provider) + } + + return codersdk.UserChatProviderConfig{ + ProviderID: provider.ID, + Provider: provider.Provider, + DisplayName: displayName, + HasUserAPIKey: hasUserAPIKey, + HasCentralAPIKeyFallback: hasCentralAPIKeyFallback, } } @@ -4904,26 +5237,80 @@ func chatProviderValidationDetail() string { return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "." } -func chatProviderAPIKeysFromDeploymentValues( - deploymentValues *codersdk.DeploymentValues, -) chatprovider.ProviderAPIKeys { - _ = deploymentValues - // For now, we'll just manage configs in the UI. - // We should probably not be reusing the AI bridge configs anyways. - return chatprovider.ProviderAPIKeys{ - // OpenAI: deploymentValues.AI.BridgeConfig.OpenAI.Key.Value(), - // Anthropic: deploymentValues.AI.BridgeConfig.Anthropic.Key.Value(), - // BaseURLByProvider: map[string]string{ - // "openai": deploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(), - // "anthropic": deploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(), - // }, +const maxChatProviderAPIKeySize = 10240 // 10 KB + +func validateChatProviderAPIKeySize(apiKey string) error { + if len(apiKey) > maxChatProviderAPIKeySize { + return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize) } + return nil +} + +//nolint:revive // This helper validates the explicit credential policy tuple. +func validateChatProviderCredentialPolicy( + centralEnabled, allowUserKey, allowFallback bool, +) error { + if !centralEnabled && !allowUserKey { + return xerrors.New( + "At least one credential source must be enabled: central API key or user API key.", + ) + } + if allowFallback && !centralEnabled { + return xerrors.New( + "Central API key fallback requires central API key to be enabled.", + ) + } + if allowFallback && !allowUserKey { + return xerrors.New( + "Central API key fallback requires user API key to be enabled.", + ) + } + return nil +} + +//nolint:revive // This helper validates central-key requirements. +func validateChatProviderCentralAPIKey( + centralEnabled bool, + hasCentralAPIKey bool, +) error { + if centralEnabled && !hasCentralAPIKey { + return xerrors.New( + "API key is required when central API key is enabled.", + ) + } + return nil +} + +// ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat +// provider API keys. +func ChatProviderAPIKeysFromDeploymentValues( + _ *codersdk.DeploymentValues, +) chatprovider.ProviderAPIKeys { + // AI bridge deployment config is intentionally not reused for chat + // provider credentials. Bridge keys serve the AI task subsystem and + // should not silently broaden into chat execution paths. + return chatprovider.ProviderAPIKeys{} } func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool { + return api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil) +} + +func (api *API) hasEffectiveCentralProviderAPIKey( + ctx context.Context, + provider database.ChatProvider, + excludeProviderID uuid.UUID, +) bool { + if !provider.CentralApiKeyEnabled { + return false + } if strings.TrimSpace(provider.APIKey) != "" { return true } + deploymentKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues) + if deploymentKeys.APIKey(provider.Provider) != "" { + return true + } if api.chatDaemon == nil { return false } @@ -4945,6 +5332,9 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas []chatprovider.ConfiguredProvider, 0, len(enabledProviders), ) for _, configured := range enabledProviders { + if excludeProviderID != uuid.Nil && configured.ID == excludeProviderID { + continue + } enabledConfiguredProviders = append( enabledConfiguredProviders, chatprovider.ConfiguredProvider{ Provider: configured.Provider, @@ -4955,7 +5345,7 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas } effectiveKeys := chatprovider.MergeProviderAPIKeys( - chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), + deploymentKeys, enabledConfiguredProviders, ) return effectiveKeys.APIKey(provider.Provider) != "" diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 2d5680c4f1..4e9b23afc8 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -34,12 +35,16 @@ import ( "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" + "github.com/coder/serpent" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) +const chatProviderAPIKeySizeLimit = 10240 + func chatDeploymentValues(t testing.TB) *codersdk.DeploymentValues { t.Helper() @@ -57,6 +62,18 @@ func newChatClient(t testing.TB) *codersdk.ExperimentalClient { return codersdk.NewExperimentalClient(client) } +func newChatClientWithDeploymentValues( + t testing.TB, + values *codersdk.DeploymentValues, +) *codersdk.ExperimentalClient { + t.Helper() + + client := coderdtest.New(t, &coderdtest.Options{ + DeploymentValues: values, + }) + return codersdk.NewExperimentalClient(client) +} + func newChatClientWithDatabase(t testing.TB) (*codersdk.ExperimentalClient, database.Store) { t.Helper() @@ -813,6 +830,180 @@ func TestListChatModels(t *testing.T) { _, err := unauthenticatedClient.ListChatModels(ctx) requireSDKError(t, err, http.StatusUnauthorized) }) + + t.Run("CentralOnlyProviderAvailable", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + models, err := client.ListChatModels(ctx) + require.NoError(t, err) + + var openAIProvider *codersdk.ChatModelProvider + for i := range models.Providers { + if models.Providers[i].Provider == "openai" { + openAIProvider = &models.Providers[i] + break + } + } + require.NotNil(t, openAIProvider) + require.True(t, openAIProvider.Available) + }) + + t.Run("UserOnlyProviderRequiresUserKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + contextLimit := int64(4096) + _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + Model: "claude-sonnet", + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + + models, err := client.ListChatModels(ctx) + require.NoError(t, err) + + var anthropicProvider *codersdk.ChatModelProvider + for i := range models.Providers { + if models.Providers[i].Provider == "anthropic" { + anthropicProvider = &models.Providers[i] + break + } + } + require.NotNil(t, anthropicProvider) + require.False(t, anthropicProvider.Available) + require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-api-key", + }) + require.NoError(t, err) + + models, err = client.ListChatModels(ctx) + require.NoError(t, err) + + anthropicProvider = nil + for i := range models.Providers { + if models.Providers[i].Provider == "anthropic" { + anthropicProvider = &models.Providers[i] + break + } + } + require.NotNil(t, anthropicProvider) + require.True(t, anthropicProvider.Available) + }) + + t.Run("CentralAndUserWithFallback", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "google", + APIKey: "central-api-key", + CentralAPIKeyEnabled: ptr.Ref(true), + AllowUserAPIKey: ptr.Ref(true), + AllowCentralAPIKeyFallback: ptr.Ref(true), + }) + require.NoError(t, err) + + contextLimit := int64(4096) + _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "google", + Model: "gemini-1.5-pro", + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + + models, err := client.ListChatModels(ctx) + require.NoError(t, err) + + var googleProvider *codersdk.ChatModelProvider + for i := range models.Providers { + if models.Providers[i].Provider == "google" { + googleProvider = &models.Providers[i] + break + } + } + require.NotNil(t, googleProvider) + require.True(t, googleProvider.Available) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-api-key", + }) + require.NoError(t, err) + + models, err = client.ListChatModels(ctx) + require.NoError(t, err) + + googleProvider = nil + for i := range models.Providers { + if models.Providers[i].Provider == "google" { + googleProvider = &models.Providers[i] + break + } + } + require.NotNil(t, googleProvider) + require.True(t, googleProvider.Available) + }) + + t.Run("DisabledProvidersAndModelsAreFilteredOut", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + values := chatDeploymentValues(t) + values.AI.BridgeConfig.OpenAI.Key = serpent.String("deployment-openai-key") + client := newChatClientWithDeploymentValues(t, values) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-key", + }) + require.NoError(t, err) + + contextLimit := int64(4096) + _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai", + Model: "gpt-4o-mini", + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + + models, err := client.ListChatModels(ctx) + require.NoError(t, err) + require.Len(t, models.Providers, 1) + require.Equal(t, "openai", models.Providers[0].Provider) + require.Len(t, models.Providers[0].Models, 1) + require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model) + + enabled := false + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + Enabled: &enabled, + }) + require.NoError(t, err) + + models, err = client.ListChatModels(ctx) + require.NoError(t, err) + require.Empty(t, models.Providers) + }) } func TestWatchChats(t *testing.T) { @@ -1212,6 +1403,34 @@ func TestListChatProviders(t *testing.T) { require.True(t, openAIProvider.HasAPIKey) }) + t.Run("IgnoresDeploymentKeyWhenCentralKeyDisabled", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + values := chatDeploymentValues(t) + values.AI.BridgeConfig.OpenAI.Key = serpent.String("deployment-openai-key") + client := newChatClientWithDeploymentValues(t, values) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + require.False(t, provider.HasAPIKey) + + providers, err := client.ListChatProviders(ctx) + require.NoError(t, err) + for _, listed := range providers { + if listed.Provider == "openai" { + require.False(t, listed.HasAPIKey) + return + } + } + t.Fatal("openai provider not found") + }) + t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() @@ -1301,6 +1520,135 @@ func TestCreateChatProvider(t *testing.T) { }) requireSDKError(t, err, http.StatusForbidden) }) + + t.Run("DefaultsPolicyFields", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + require.True(t, provider.CentralAPIKeyEnabled) + require.False(t, provider.AllowUserAPIKey) + require.False(t, provider.AllowCentralAPIKeyFallback) + }) + + t.Run("UserOnlyDoesNotRequireCentralKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + require.False(t, provider.CentralAPIKeyEnabled) + require.True(t, provider.AllowUserAPIKey) + require.False(t, provider.AllowCentralAPIKeyFallback) + require.False(t, provider.HasAPIKey) + }) + + t.Run("RejectsDeploymentBackedCentralKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + values := chatDeploymentValues(t) + values.AI.BridgeConfig.OpenAI.Key = serpent.String("deployment-openai-key") + client := newChatClientWithDeploymentValues(t, values) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) + }) + + t.Run("RejectsInvalidPolicyTuple", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + testCases := []struct { + name string + central bool + user bool + fallback bool + }{ + { + name: "NoneEnabled", + central: false, + user: false, + fallback: false, + }, + { + name: "FallbackWithoutCentral", + central: false, + user: true, + fallback: true, + }, + { + name: "FallbackWithoutUser", + central: true, + user: false, + fallback: true, + }, + } + + for _, testCase := range testCases { + _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + CentralAPIKeyEnabled: ptr.Ref(testCase.central), + AllowUserAPIKey: ptr.Ref(testCase.user), + AllowCentralAPIKeyFallback: ptr.Ref(testCase.fallback), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equalf(t, "Invalid credential policy.", sdkErr.Message, "case %s", testCase.name) + } + }) + + t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit+1), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key too large.", sdkErr.Message) + require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail) + }) + + t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit), + }) + require.NoError(t, err) + require.True(t, provider.HasAPIKey) + }) } func TestUpdateChatProvider(t *testing.T) { @@ -1387,6 +1735,184 @@ func TestUpdateChatProvider(t *testing.T) { }) requireSDKError(t, err, http.StatusForbidden) }) + + t.Run("AppliesPolicyOverrides", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + updated, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + require.True(t, updated.AllowUserAPIKey) + require.False(t, updated.CentralAPIKeyEnabled) + require.False(t, updated.HasAPIKey) + }) + + t.Run("RejectsDeploymentBackedCentralKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + values := chatDeploymentValues(t) + values.AI.BridgeConfig.OpenAI.Key = serpent.String("deployment-openai-key") + client := newChatClientWithDeploymentValues(t, values) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + CentralAPIKeyEnabled: ptr.Ref(true), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) + }) + + t.Run("RejectsClearingLastCentralKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + APIKey: ptr.Ref(""), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) + }) + + t.Run("RejectsEnablingCentralKeyWithoutKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + CentralAPIKeyEnabled: ptr.Ref(true), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) + }) + + t.Run("RejectsInvalidPolicyTuple", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + testCases := []struct { + name string + central bool + user bool + fallback bool + }{ + { + name: "NoneEnabled", + central: false, + user: false, + fallback: false, + }, + { + name: "FallbackWithoutCentral", + central: false, + user: true, + fallback: true, + }, + { + name: "FallbackWithoutUser", + central: true, + user: false, + fallback: true, + }, + } + + for _, testCase := range testCases { + _, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + CentralAPIKeyEnabled: ptr.Ref(testCase.central), + AllowUserAPIKey: ptr.Ref(testCase.user), + AllowCentralAPIKeyFallback: ptr.Ref(testCase.fallback), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equalf(t, "Invalid credential policy.", sdkErr.Message, "case %s", testCase.name) + } + }) + + t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + APIKey: ptr.Ref(strings.Repeat("a", chatProviderAPIKeySizeLimit+1)), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key too large.", sdkErr.Message) + require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail) + }) + + t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + updated, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + APIKey: ptr.Ref(strings.Repeat("a", chatProviderAPIKeySizeLimit)), + }) + require.NoError(t, err) + require.True(t, updated.HasAPIKey) + }) } func TestDeleteChatProvider(t *testing.T) { @@ -1467,6 +1993,471 @@ func TestDeleteChatProvider(t *testing.T) { }) } +func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) { + t.Parallel() + + t.Run("DoesNotReuseBridgeConfig", func(t *testing.T) { + t.Parallel() + + values := chatDeploymentValues(t) + values.AI.BridgeConfig.OpenAI.Key = serpent.String("deployment-openai-key") + values.AI.BridgeConfig.Anthropic.Key = serpent.String("deployment-anthropic-key") + values.AI.BridgeConfig.OpenAI.BaseURL = serpent.String("https://custom-openai.example.com") + + keys := coderd.ChatProviderAPIKeysFromDeploymentValues(values) + require.Equal(t, chatprovider.ProviderAPIKeys{}, keys) + }) + + t.Run("NilDeploymentValues", func(t *testing.T) { + t.Parallel() + + keys := coderd.ChatProviderAPIKeysFromDeploymentValues(nil) + require.Equal(t, chatprovider.ProviderAPIKeys{}, keys) + }) +} + +func TestUserChatProviderConfigs(t *testing.T) { + t.Parallel() + + requireUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) codersdk.UserChatProviderConfig { + t.Helper() + + for _, config := range configs { + if config.Provider == provider { + return config + } + } + + t.Fatalf("provider %q not found", provider) + return codersdk.UserChatProviderConfig{} + } + + requireNoUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) { + t.Helper() + + for _, config := range configs { + if config.Provider == provider { + t.Fatalf("provider %q unexpectedly found", provider) + } + } + } + + t.Run("ListOnlyUserKeyProviders", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + anthropicProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "google", + APIKey: "central-api-key", + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + require.Len(t, configs, 1) + require.Equal(t, anthropicProvider.ID, configs[0].ProviderID) + require.Equal(t, anthropicProvider.Provider, configs[0].Provider) + }) + + t.Run("ListReportsHasUserAPIKeyFalse", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + require.Len(t, configs, 1) + require.Equal(t, provider.ID, configs[0].ProviderID) + require.False(t, configs[0].HasUserAPIKey) + }) + + t.Run("ListHidesDisabledProviderEvenWithSavedKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + require.NoError(t, err) + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + Enabled: ptr.Ref(false), + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + require.Empty(t, configs) + requireNoUserProviderConfig(t, configs, "anthropic") + }) + + t.Run("ListHidesUserKeyDisabledProviderAndRestoresOnReEnable", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + require.NoError(t, err) + + centralAPIKey := "central-key" + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + APIKey: ¢ralAPIKey, + CentralAPIKeyEnabled: ptr.Ref(true), + AllowUserAPIKey: ptr.Ref(false), + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + require.Empty(t, configs) + requireNoUserProviderConfig(t, configs, "anthropic") + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + configs, err = client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed := requireUserProviderConfig(t, configs, "anthropic") + require.Equal(t, provider.ID, listed.ProviderID) + require.True(t, listed.HasUserAPIKey) + require.False(t, listed.HasCentralAPIKeyFallback) + }) + + t.Run("UpsertCreatesKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + APIKey: "central-key", + CentralAPIKeyEnabled: ptr.Ref(true), + AllowUserAPIKey: ptr.Ref(true), + AllowCentralAPIKeyFallback: ptr.Ref(true), + }) + require.NoError(t, err) + + config, err := client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + require.NoError(t, err) + require.Equal(t, provider.ID, config.ProviderID) + require.Equal(t, provider.Provider, config.Provider) + require.Equal(t, provider.DisplayName, config.DisplayName) + require.True(t, config.HasUserAPIKey) + require.True(t, config.HasCentralAPIKeyFallback) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed := requireUserProviderConfig(t, configs, "anthropic") + require.Equal(t, provider.ID, listed.ProviderID) + require.Equal(t, provider.DisplayName, listed.DisplayName) + require.True(t, listed.HasUserAPIKey) + require.True(t, listed.HasCentralAPIKeyFallback) + }) + + t.Run("ListRecomputesFallbackAvailability", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + values := chatDeploymentValues(t) + values.AI.BridgeConfig.OpenAI.Key = serpent.String("deployment-openai-key") + client := newChatClientWithDeploymentValues(t, values) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-central-key", + AllowUserAPIKey: ptr.Ref(true), + AllowCentralAPIKeyFallback: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed := requireUserProviderConfig(t, configs, "openai") + require.True(t, listed.HasCentralAPIKeyFallback) + + _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + CentralAPIKeyEnabled: ptr.Ref(false), + AllowCentralAPIKeyFallback: ptr.Ref(false), + }) + require.NoError(t, err) + + configs, err = client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed = requireUserProviderConfig(t, configs, "openai") + require.False(t, listed.HasCentralAPIKeyFallback) + }) + + t.Run("UpsertUpdatesKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "key-1", + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "key-2", + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed := requireUserProviderConfig(t, configs, "anthropic") + require.True(t, listed.HasUserAPIKey) + }) + + t.Run("UpsertRejectsMissingProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.UpsertUserChatProviderKey(ctx, uuid.New(), codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("UpsertRejectsDisabledProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + Enabled: ptr.Ref(false), + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Provider is disabled.", sdkErr.Message) + }) + + t.Run("UpsertRejectsProviderWithoutUserKeys", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "google", + APIKey: "central-api-key", + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Provider does not allow user API keys.", sdkErr.Message) + }) + + t.Run("UpsertRejectsEmptyAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "", + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required.", sdkErr.Message) + }) + + t.Run("DeleteRemovesKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "user-key", + }) + require.NoError(t, err) + + configs, err := client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed := requireUserProviderConfig(t, configs, "anthropic") + require.True(t, listed.HasUserAPIKey) + + err = client.DeleteUserChatProviderKey(ctx, provider.ID) + require.NoError(t, err) + + configs, err = client.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed = requireUserProviderConfig(t, configs, "anthropic") + require.False(t, listed.HasUserAPIKey) + + err = client.DeleteUserChatProviderKey(ctx, provider.ID) + require.NoError(t, err) + }) + + t.Run("OtherUserDoesNotSeeKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + + provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = adminClient.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: "admin-user-key", + }) + require.NoError(t, err) + + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + configs, err := memberClient.ListUserChatProviderConfigs(ctx) + require.NoError(t, err) + listed := requireUserProviderConfig(t, configs, "anthropic") + require.Equal(t, provider.ID, listed.ProviderID) + require.False(t, listed.HasUserAPIKey) + }) +} + +func TestUpsertUserChatProviderKey(t *testing.T) { + t.Parallel() + + t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit+1), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key too large.", sdkErr.Message) + require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail) + }) + + t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "anthropic", + CentralAPIKeyEnabled: ptr.Ref(false), + AllowUserAPIKey: ptr.Ref(true), + }) + require.NoError(t, err) + + config, err := client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit), + }) + require.NoError(t, err) + require.True(t, config.HasUserAPIKey) + }) +} + func TestListChatModelConfigs(t *testing.T) { t.Parallel() @@ -3943,12 +4934,11 @@ func TestRegenerateChatTitle(t *testing.T) { ) require.NoError(t, err) defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusConflict, res.StatusCode) - var resp codersdk.Chat + var resp codersdk.Response require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) - require.Equal(t, chat.ID, resp.ID) - require.Equal(t, "pending chat without worker", resp.Title) + require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message) persisted, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) diff --git a/coderd/httpmw/chatparam_test.go b/coderd/httpmw/chatparam_test.go index 52c97e0c5b..8585b9462c 100644 --- a/coderd/httpmw/chatparam_test.go +++ b/coderd/httpmw/chatparam_test.go @@ -39,13 +39,14 @@ func TestChatParam(t *testing.T) { t.Helper() _, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-api-key", - BaseUrl: "https://api.openai.com/v1", - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true}, - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-api-key", + BaseUrl: "https://api.openai.com/v1", + ApiKeyKeyID: sql.NullString{}, + CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 0b3b34fe33..3f739f33dd 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -1585,17 +1585,17 @@ func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) e if err != nil { return xerrors.Errorf("lock chat for manual title regeneration: %w", err) } - if isFreshManualTitleLock(lockedChat, now) { + // Only a fresh manual lock or a chat without a real worker should + // block title regeneration. Running chats with a real worker may + // regenerate their title concurrently, and last write wins. + hasRealWorker := lockedChat.Status == database.ChatStatusRunning && + lockedChat.WorkerID.Valid && + lockedChat.WorkerID.UUID != manualTitleLockWorkerID + if lockedChat.Status == database.ChatStatusPending || + (lockedChat.Status == database.ChatStatusRunning && !hasRealWorker) || + isFreshManualTitleLock(lockedChat, now) { return ErrManualTitleRegenerationInProgress } - - // Only write the lock marker when no real worker owns WorkerID. - // When a real worker is running, we skip the DB lock but still - // allow regeneration. The frontend prevents same-browser - // double-clicks, and concurrent regeneration from different - // replicas is harmless, last write wins. - hasRealWorker := lockedChat.WorkerID.Valid && - lockedChat.WorkerID.UUID != manualTitleLockWorkerID if hasRealWorker { return nil } @@ -1658,7 +1658,7 @@ func (p *Server) RegenerateChatTitle( // keeping chat ownership authorization at the HTTP layer. //nolint:gocritic // Non-admin users need chatd-scoped config reads here. chatdCtx := dbauthz.AsChatd(ctx) - keys, err := p.resolveProviderAPIKeys(chatdCtx) + keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID) if err != nil { return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err) } @@ -4808,7 +4808,7 @@ func (p *Server) resolveChatModel( }) g.Go(func() error { var err error - keys, err = p.resolveProviderAPIKeys(ctx) + keys, err = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID) if err != nil { return xerrors.Errorf("resolve provider API keys: %w", err) } @@ -4830,8 +4830,9 @@ func (p *Server) resolveChatModel( return model, dbConfig, keys, nil } -func (p *Server) resolveProviderAPIKeys( +func (p *Server) resolveUserProviderAPIKeys( ctx context.Context, + ownerID uuid.UUID, ) (chatprovider.ProviderAPIKeys, error) { providers, err := p.configCache.EnabledProviders(ctx) if err != nil { @@ -4840,17 +4841,62 @@ func (p *Server) resolveProviderAPIKeys( err, ) } - dbProviders := make( + configuredProviders := make( []chatprovider.ConfiguredProvider, 0, len(providers), ) for _, provider := range providers { - dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{ - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - }) + configuredProviders = append( + configuredProviders, chatprovider.ConfiguredProvider{ + ProviderID: provider.ID, + Provider: provider.Provider, + APIKey: provider.APIKey, + BaseURL: provider.BaseUrl, + CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, + AllowUserAPIKey: provider.AllowUserApiKey, + AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, + }, + ) } - return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil + allowAnyUserAPIKey := false + for _, provider := range configuredProviders { + if provider.AllowUserAPIKey { + allowAnyUserAPIKey = true + break + } + } + + userKeys := []chatprovider.UserProviderKey{} + if allowAnyUserAPIKey { + userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID) + if err != nil { + return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( + "get user chat provider keys: %w", + err, + ) + } + userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) + for _, userKey := range userKeyRows { + userKeys = append(userKeys, chatprovider.UserProviderKey{ + ChatProviderID: userKey.ChatProviderID, + APIKey: userKey.APIKey, + }) + } + } + keys, _ := chatprovider.ResolveUserProviderKeys( + p.providerAPIKeys, + configuredProviders, + userKeys, + ) + enabledProviders := make(map[string]struct{}, len(configuredProviders)) + for _, provider := range configuredProviders { + normalizedProvider := chatprovider.NormalizeProvider(provider.Provider) + if normalizedProvider == "" { + continue + } + enabledProviders[normalizedProvider] = struct{}{} + } + chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders) + return keys, nil } // resolveModelConfig looks up the chat's model config by its diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index cac56b2ebd..4382bec01e 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -23,6 +23,7 @@ import ( dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/codersdk" @@ -99,9 +100,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - APIKey: "test-key", - BaseUrl: serverURL, + Provider: "openai", + CentralApiKeyEnabled: true, + APIKey: "test-key", + BaseUrl: serverURL, }}, nil) db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( @@ -261,9 +263,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - APIKey: "test-key", - BaseUrl: serverURL, + Provider: "openai", + CentralApiKeyEnabled: true, + APIKey: "test-key", + BaseUrl: serverURL, }}, nil) db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( @@ -378,6 +381,87 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t } } +func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + + server := &Server{ + db: db, + configCache: newChatConfigCache( + context.Background(), + db, + quartz.NewReal(), + ), + providerAPIKeys: chatprovider.ProviderAPIKeys{ + OpenAI: "openai-deployment-key", + Anthropic: "anthropic-deployment-key", + ByProvider: map[string]string{ + "openai": "openai-deployment-key", + "anthropic": "anthropic-deployment-key", + }, + BaseURLByProvider: map[string]string{ + "openai": "https://openai.example.com", + "anthropic": "https://anthropic.example.com", + }, + }, + } + + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ + Provider: "anthropic", + CentralApiKeyEnabled: true, + AllowCentralApiKeyFallback: true, + }}, nil) + + keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID) + require.NoError(t, err) + require.Empty(t, keys.OpenAI) + require.Empty(t, keys.APIKey("openai")) + require.Empty(t, keys.BaseURL("openai")) + require.Equal(t, "anthropic-deployment-key", keys.Anthropic) + require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic")) + require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic")) + require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider) + require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider) +} + +func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + + server := &Server{ + db: db, + configCache: newChatConfigCache( + context.Background(), + db, + quartz.NewReal(), + ), + providerAPIKeys: chatprovider.ProviderAPIKeys{ + OpenAI: "openai-deployment-key", + ByProvider: map[string]string{ + "openai": "openai-deployment-key", + }, + }, + } + + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ + Provider: "openai", + CentralApiKeyEnabled: true, + }}, nil) + + keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID) + require.NoError(t, err) + require.Equal(t, "openai-deployment-key", keys.OpenAI) + require.Equal(t, "openai-deployment-key", keys.APIKey("openai")) +} + func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) { t.Parallel() @@ -523,7 +607,8 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { workspacesdk.LSResponse{}, codersdk.NewTestError(404, "POST", "/api/v0/list-directory"), ).AnyTimes() - conn.EXPECT().ReadFile(gomock.Any(), + conn.EXPECT().ReadFile( + gomock.Any(), "/home/coder/project/AGENTS.md", int64(0), int64(maxInstructionFileBytes+1)).Return( diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 00a1f49f50..5fc7a439df 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -2893,12 +2893,13 @@ func seedChatDependenciesWithProvider( user := dbgen.User(t, db, database.User{}) _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: provider, - DisplayName: provider, - APIKey: "test-key", - BaseUrl: baseURL, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - Enabled: true, + Provider: provider, + DisplayName: provider, + APIKey: "test-key", + BaseUrl: baseURL, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ @@ -2917,6 +2918,102 @@ func seedChatDependenciesWithProvider( return user, model } +func seedChatDependenciesWithProviderPolicy( + ctx context.Context, + t *testing.T, + db database.Store, + provider string, + baseURL string, + apiKey string, + centralAPIKeyEnabled bool, + allowUserAPIKey bool, + allowCentralAPIKeyFallback bool, +) (database.User, database.ChatProvider, database.ChatModelConfig) { + t.Helper() + + user := dbgen.User(t, db, database.User{}) + providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: provider, + DisplayName: provider, + APIKey: apiKey, + BaseUrl: baseURL, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: centralAPIKeyEnabled, + AllowUserApiKey: allowUserAPIKey, + AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, + }) + require.NoError(t, err) + + model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + Provider: provider, + Model: "gpt-4o-mini", + DisplayName: "Test Model", + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 70, + Options: json.RawMessage(`{}`), + }) + require.NoError(t, err) + + return user, providerConfig, model +} + +func waitForTerminalChatStatusEvent( + ctx context.Context, + t *testing.T, + events <-chan codersdk.ChatStreamEvent, +) codersdk.ChatStatus { + t.Helper() + + var terminalStatus codersdk.ChatStatus + testutil.Eventually(ctx, t, func(context.Context) bool { + for { + select { + case event, ok := <-events: + if !ok { + return false + } + if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil { + continue + } + if event.Status.Status == codersdk.ChatStatusWaiting || event.Status.Status == codersdk.ChatStatusError { + terminalStatus = event.Status.Status + return true + } + default: + return false + } + } + }, testutil.IntervalFast) + + return terminalStatus +} + +func waitForTerminalChat( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, +) database.Chat { + t.Helper() + + var chatResult database.Chat + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + got, err := db.GetChatByID(ctx, chatID) + if err != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.IntervalFast) + + return chatResult +} + // seedWorkspaceWithAgent creates a full workspace chain with a connected // agent. This is the common setup needed by tests that exercise tool // execution against a workspace. @@ -2973,12 +3070,15 @@ func setOpenAIProviderBaseURL( require.NoError(t, err) _, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: baseURL, - ApiKeyKeyID: provider.ApiKeyKeyID, - Enabled: provider.Enabled, + ID: provider.ID, + DisplayName: provider.DisplayName, + APIKey: provider.APIKey, + BaseUrl: baseURL, + ApiKeyKeyID: provider.ApiKeyKeyID, + Enabled: provider.Enabled, + CentralApiKeyEnabled: provider.CentralApiKeyEnabled, + AllowUserApiKey: provider.AllowUserApiKey, + AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, }) require.NoError(t, err) } @@ -3552,12 +3652,13 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { // Add an Anthropic provider pointing to our mock server. _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "anthropic", - DisplayName: "Anthropic", - APIKey: "test-anthropic-key", - BaseUrl: anthropicSrv.URL, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - Enabled: true, + Provider: "anthropic", + DisplayName: "Anthropic", + APIKey: "test-anthropic-key", + BaseUrl: anthropicSrv.URL, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -3841,6 +3942,135 @@ func TestInterruptChatPersistsPartialResponse(t *testing.T) { "partial assistant response should contain the streamed text") } +func TestProcessChat_UserProviderKey_Success(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + const userAPIKey = "user-test-key" + + var authHeadersMu sync.Mutex + authHeaders := make([]string, 0, 1) + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + authHeadersMu.Lock() + authHeaders = append(authHeaders, req.Header.Get("Authorization")) + authHeadersMu.Unlock() + + if !req.Stream { + return chattest.OpenAINonStreamingResponse("user provider key success") + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("hello from the saved user key")..., + ) + }) + + user, provider, model := seedChatDependenciesWithProviderPolicy( + ctx, + t, + db, + "openai-compat", + openAIURL, + "", + false, + true, + false, + ) + _, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ + UserID: user.ID, + ChatProviderID: provider.ID, + APIKey: userAPIKey, + }) + require.NoError(t, err) + + creator := newTestServer(t, db, ps, uuid.New()) + chat, err := creator.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "user-provider-key-success", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("say hello"), + }, + }) + require.NoError(t, err) + + _, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + _ = newActiveTestServer(t, db, ps) + + terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events) + require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus) + + chatResult := waitForTerminalChat(ctx, t, db, chat.ID) + require.Equal(t, database.ChatStatusWaiting, chatResult.Status) + require.False(t, chatResult.LastError.Valid) + + authHeadersMu.Lock() + recordedAuthHeaders := append([]string(nil), authHeaders...) + authHeadersMu.Unlock() + require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey) +} + +func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var llmCalls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + llmCalls.Add(1) + if !req.Stream { + return chattest.OpenAINonStreamingResponse("unexpected non-streaming request") + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("unexpected streaming request")..., + ) + }) + + user, _, model := seedChatDependenciesWithProviderPolicy( + ctx, + t, + db, + "openai-compat", + openAIURL, + "", + false, + true, + false, + ) + + creator := newTestServer(t, db, ps, uuid.New()) + chat, err := creator.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "user-provider-key-missing", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("say hello"), + }, + }) + require.NoError(t, err) + + _, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + _ = newActiveTestServer(t, db, ps) + + terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events) + require.Equal(t, codersdk.ChatStatusError, terminalStatus) + + chatResult := waitForTerminalChat(ctx, t, db, chat.ID) + require.Equal(t, database.ChatStatusError, chatResult.Status) + require.True(t, chatResult.LastError.Valid, "LastError should be set") + require.NotEmpty(t, chatResult.LastError.String) + require.NotContains(t, chatResult.LastError.String, "panicked") + require.NotEqual(t, database.ChatStatusRunning, chatResult.Status) + require.Zero(t, llmCalls.Load(), "missing user key should fail before any LLM request") +} + func TestProcessChatPanicRecovery(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatprompt/chatprompt_test.go b/coderd/x/chatd/chatprompt/chatprompt_test.go index 703d1c3d7a..766baca98f 100644 --- a/coderd/x/chatd/chatprompt/chatprompt_test.go +++ b/coderd/x/chatd/chatprompt/chatprompt_test.go @@ -1459,11 +1459,12 @@ func TestNulEscapeRoundTrip(t *testing.T) { user := dbgen.User(t, db, database.User{}) _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "openai", - APIKey: "test-key", - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - Enabled: true, + Provider: "openai", + DisplayName: "openai", + APIKey: "test-key", + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) @@ -1943,11 +1944,12 @@ func TestMediaToolResultRoundTrip(t *testing.T) { user := dbgen.User(t, db, database.User{}) _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "anthropic", - DisplayName: "anthropic", - APIKey: "test-key", - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - Enabled: true, + Provider: "anthropic", + DisplayName: "anthropic", + APIKey: "test-key", + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index 59164a8cbb..0c3f84cf0b 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -81,11 +81,28 @@ type ProviderAPIKeys struct { BaseURLByProvider map[string]string } +// UserProviderKey is a user-supplied API key for a specific provider. +type UserProviderKey struct { + ChatProviderID uuid.UUID + APIKey string +} + +// ProviderAvailability describes whether a provider has a usable +// API key and, if not, why. +type ProviderAvailability struct { + Available bool + UnavailableReason codersdk.ChatModelProviderUnavailableReason +} + // ConfiguredProvider is an enabled provider loaded from database config. type ConfiguredProvider struct { - Provider string - APIKey string - BaseURL string + ProviderID uuid.UUID + Provider string + APIKey string + BaseURL string + CentralAPIKeyEnabled bool + AllowUserAPIKey bool + AllowCentralAPIKeyFallback bool } // ConfiguredModel is an enabled model loaded from database config. @@ -189,21 +206,146 @@ func MergeProviderAPIKeys(fallback ProviderAPIKeys, providers []ConfiguredProvid return merged } -type ModelCatalog struct { - keys ProviderAPIKeys +// ResolveUserProviderKeys computes effective API keys and per-provider +// availability for a given user. It considers the provider's credential +// policy flags alongside central (DB/deployment) keys and the user's +// personal keys. +func ResolveUserProviderKeys( + fallback ProviderAPIKeys, + providers []ConfiguredProvider, + userKeys []UserProviderKey, +) (ProviderAPIKeys, map[string]ProviderAvailability) { + merged := ProviderAPIKeys{ + OpenAI: strings.TrimSpace(fallback.OpenAI), + Anthropic: strings.TrimSpace(fallback.Anthropic), + ByProvider: map[string]string{}, + BaseURLByProvider: map[string]string{}, + } + for provider, apiKey := range fallback.ByProvider { + normalizedProvider := NormalizeProvider(provider) + if normalizedProvider == "" { + continue + } + if key := strings.TrimSpace(apiKey); key != "" { + merged.ByProvider[normalizedProvider] = key + } + } + for provider, baseURL := range fallback.BaseURLByProvider { + normalizedProvider := NormalizeProvider(provider) + if normalizedProvider == "" { + continue + } + if url := strings.TrimSpace(baseURL); url != "" { + merged.BaseURLByProvider[normalizedProvider] = url + } + } + if merged.OpenAI != "" { + merged.ByProvider[fantasyopenai.Name] = merged.OpenAI + } + if merged.Anthropic != "" { + merged.ByProvider[fantasyanthropic.Name] = merged.Anthropic + } + + userKeyByProviderID := make(map[uuid.UUID]string, len(userKeys)) + for _, userKey := range userKeys { + if userKey.ChatProviderID == uuid.Nil { + continue + } + if key := strings.TrimSpace(userKey.APIKey); key != "" { + userKeyByProviderID[userKey.ChatProviderID] = key + } + } + + availabilityByProvider := make(map[string]ProviderAvailability, len(providers)) + for _, provider := range providers { + normalizedProvider := NormalizeProvider(provider.Provider) + if normalizedProvider == "" { + continue + } + + if url := strings.TrimSpace(provider.BaseURL); url != "" { + merged.BaseURLByProvider[normalizedProvider] = url + } + + var userKey string + if provider.ProviderID != uuid.Nil { + userKey = userKeyByProviderID[provider.ProviderID] + } + + var centralKey string + if provider.CentralAPIKeyEnabled { + if key := strings.TrimSpace(provider.APIKey); key != "" { + centralKey = key + } else { + centralKey = fallback.APIKey(normalizedProvider) + } + } + + resolved := ProviderAvailability{} + chosenKey := "" + switch { + case provider.AllowUserAPIKey && userKey != "": + chosenKey = userKey + resolved.Available = true + case centralKey != "": + if !provider.AllowUserAPIKey || provider.AllowCentralAPIKeyFallback { + chosenKey = centralKey + resolved.Available = true + } else { + resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired + } + case provider.AllowUserAPIKey && provider.AllowCentralAPIKeyFallback && provider.CentralAPIKeyEnabled: + // When users can add their own key, a missing central fallback key is + // still something the user can remedy. + resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired + case provider.AllowUserAPIKey: + resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired + default: + resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey + } + + setResolvedProviderAPIKey(&merged, normalizedProvider, chosenKey) + availabilityByProvider[normalizedProvider] = resolved + } + + return merged, availabilityByProvider } -func NewModelCatalog(keys ProviderAPIKeys) *ModelCatalog { - return &ModelCatalog{ - keys: keys, +func setResolvedProviderAPIKey(keys *ProviderAPIKeys, provider string, apiKey string) { + normalizedProvider := NormalizeProvider(provider) + if normalizedProvider == "" { + return } + if keys.ByProvider == nil { + keys.ByProvider = map[string]string{} + } + + delete(keys.ByProvider, normalizedProvider) + trimmedKey := strings.TrimSpace(apiKey) + switch normalizedProvider { + case fantasyopenai.Name: + keys.OpenAI = trimmedKey + case fantasyanthropic.Name: + keys.Anthropic = trimmedKey + } + if trimmedKey != "" { + keys.ByProvider[normalizedProvider] = trimmedKey + } +} + +type ModelCatalog struct{} + +func NewModelCatalog() *ModelCatalog { + return &ModelCatalog{} } // ListConfiguredModels returns a model catalog from enabled DB-backed model // configs. The second return value reports whether DB-backed models were used. -func (c *ModelCatalog) ListConfiguredModels( +func (*ModelCatalog) ListConfiguredModels( configuredProviders []ConfiguredProvider, configuredModels []ConfiguredModel, + availabilityByProvider map[string]ProviderAvailability, + enabledProviders map[string]struct{}, ) (codersdk.ChatModelsResponse, bool) { if len(configuredModels) == 0 { return codersdk.ChatModelsResponse{}, false @@ -247,11 +389,14 @@ func (c *ModelCatalog) ListConfiguredModels( return codersdk.ChatModelsResponse{}, false } - keys := MergeProviderAPIKeys(c.keys, configuredProviders) response := codersdk.ChatModelsResponse{ Providers: make([]codersdk.ChatModelProvider, 0, len(providers)), } for _, provider := range providers { + if _, ok := enabledProviders[provider]; !ok { + continue + } + models := modelsByProvider[provider] sortChatModels(models) @@ -259,11 +404,14 @@ func (c *ModelCatalog) ListConfiguredModels( Provider: provider, Models: models, } - if keys.APIKey(provider) == "" { + if avail, ok := availabilityByProvider[provider]; ok { + result.Available = avail.Available + if !avail.Available { + result.UnavailableReason = avail.UnavailableReason + } + } else { result.Available = false result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey - } else { - result.Available = true } response.Providers = append(response.Providers, result) @@ -273,25 +421,32 @@ func (c *ModelCatalog) ListConfiguredModels( } // ListConfiguredProviderAvailability returns provider availability derived from -// deployment/env keys merged with enabled DB provider keys. -func (c *ModelCatalog) ListConfiguredProviderAvailability( - configuredProviders []ConfiguredProvider, +// the policy-aware availability map for enabled providers. +func (*ModelCatalog) ListConfiguredProviderAvailability( + availabilityByProvider map[string]ProviderAvailability, + enabledProviders map[string]struct{}, ) codersdk.ChatModelsResponse { - keys := MergeProviderAPIKeys(c.keys, configuredProviders) response := codersdk.ChatModelsResponse{ Providers: make([]codersdk.ChatModelProvider, 0, len(supportedProviderNames)), } for _, provider := range supportedProviderNames { + if _, ok := enabledProviders[provider]; !ok { + continue + } + result := codersdk.ChatModelProvider{ Provider: provider, Models: []codersdk.ChatModel{}, } - if keys.APIKey(provider) == "" { + if avail, ok := availabilityByProvider[provider]; ok { + result.Available = avail.Available + if !avail.Available { + result.UnavailableReason = avail.UnavailableReason + } + } else { result.Available = false result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey - } else { - result.Available = true } response.Providers = append(response.Providers, result) @@ -300,6 +455,27 @@ func (c *ModelCatalog) ListConfiguredProviderAvailability( return response } +// PruneDisabledProviderKeys removes entries from keys that do not +// belong to an enabled provider. It clears ByProvider and +// BaseURLByProvider entries for disabled providers and zeroes the +// legacy OpenAI and Anthropic fields when those providers are not +// enabled. +func PruneDisabledProviderKeys(keys *ProviderAPIKeys, enabledProviders map[string]struct{}) { + for provider := range keys.ByProvider { + if _, ok := enabledProviders[provider]; ok { + continue + } + delete(keys.ByProvider, provider) + delete(keys.BaseURLByProvider, provider) + } + if _, ok := enabledProviders[NormalizeProvider("openai")]; !ok { + keys.OpenAI = "" + } + if _, ok := enabledProviders[NormalizeProvider("anthropic")]; !ok { + keys.Anthropic = "" + } +} + func newChatModel(provider, modelID, displayName string) codersdk.ChatModel { name := strings.TrimSpace(displayName) if name == "" { diff --git a/coderd/x/chatd/chatprovider/chatprovider_test.go b/coderd/x/chatd/chatprovider/chatprovider_test.go index f312ddc9fd..fa2a6a1e03 100644 --- a/coderd/x/chatd/chatprovider/chatprovider_test.go +++ b/coderd/x/chatd/chatprovider/chatprovider_test.go @@ -21,6 +21,166 @@ import ( "github.com/coder/coder/v2/testutil" ) +func TestResolveUserProviderKeys(t *testing.T) { + t.Parallel() + + configuredProvider := func(id uuid.UUID, provider string, centralEnabled bool, centralKey string, allowUser bool, allowCentralFallback bool) chatprovider.ConfiguredProvider { + return chatprovider.ConfiguredProvider{ + ProviderID: id, + Provider: provider, + APIKey: centralKey, + CentralAPIKeyEnabled: centralEnabled, + AllowUserAPIKey: allowUser, + AllowCentralAPIKeyFallback: allowCentralFallback, + } + } + + userProviderKey := func(id uuid.UUID, apiKey string) chatprovider.UserProviderKey { + return chatprovider.UserProviderKey{ + ChatProviderID: id, + APIKey: apiKey, + } + } + + openAIProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000001") + anthropicProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000002") + + tests := []struct { + name string + fallback chatprovider.ProviderAPIKeys + providers []chatprovider.ConfiguredProvider + userKeys []chatprovider.UserProviderKey + wantAvailability map[string]chatprovider.ProviderAvailability + wantKeys map[string]string + }{ + { + name: "CentralOnlyKeyPresent", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false)}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: true}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "sk-central", + }, + }, + { + name: "CentralOnlyKeyMissing", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", false, false)}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "", + }, + }, + { + name: "UserOnlyUserHasKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)}, + userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: true}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "sk-user", + }, + }, + { + name: "UserOnlyUserHasNoKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "", + }, + }, + { + name: "BothEnabledFallbackOffUserHasKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)}, + userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: true}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "sk-user", + }, + }, + { + name: "BothEnabledFallbackOffUserHasNoKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "", + }, + }, + { + name: "BothEnabledFallbackOnUserHasKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)}, + userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: true}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "sk-user", + }, + }, + { + name: "BothEnabledFallbackOnUserHasNoKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: true}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "sk-central", + }, + }, + { + name: "BothEnabledFallbackOnCentralKeyEmptyUserHasNoKey", + providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", true, true)}, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "", + }, + }, + { + name: "MultipleProvidersDifferentPolicies", + providers: []chatprovider.ConfiguredProvider{ + configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false), + configuredProvider(anthropicProviderID, fantasyanthropic.Name, false, "", true, false), + }, + wantAvailability: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: {Available: true}, + fantasyanthropic.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired}, + }, + wantKeys: map[string]string{ + fantasyopenai.Name: "sk-central", + fantasyanthropic.Name: "", + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + keys, availability := chatprovider.ResolveUserProviderKeys(tt.fallback, tt.providers, tt.userKeys) + + require.Len(t, availability, len(tt.wantAvailability)) + for provider, wantAvailability := range tt.wantAvailability { + gotAvailability, ok := availability[provider] + require.True(t, ok, "expected availability for provider %q", provider) + require.Equal(t, wantAvailability, gotAvailability) + require.Equal(t, tt.wantKeys[provider], keys.APIKey(provider)) + } + }) + } +} + func TestReasoningEffortFromChat(t *testing.T) { t.Parallel() @@ -91,6 +251,413 @@ func TestReasoningEffortFromChat(t *testing.T) { } } +func TestResolveUserProviderKeys_UnavailableReason(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + provider chatprovider.ConfiguredProvider + wantReason codersdk.ChatModelProviderUnavailableReason + }{ + { + name: "FallbackConfiguredWithoutCentralKeyReturnsUserAPIKeyRequired", + provider: chatprovider.ConfiguredProvider{ + Provider: "anthropic", + CentralAPIKeyEnabled: true, + AllowUserAPIKey: true, + AllowCentralAPIKeyFallback: true, + }, + wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, + }, + { + name: "UserKeyRequiredWithoutFallback", + provider: chatprovider.ConfiguredProvider{ + Provider: "anthropic", + CentralAPIKeyEnabled: true, + AllowUserAPIKey: true, + }, + wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + keys, availability := chatprovider.ResolveUserProviderKeys( + chatprovider.ProviderAPIKeys{}, + []chatprovider.ConfiguredProvider{tt.provider}, + nil, + ) + + require.Empty(t, keys.APIKey(tt.provider.Provider)) + resolved, ok := availability[tt.provider.Provider] + require.True(t, ok) + require.False(t, resolved.Available) + require.Equal(t, tt.wantReason, resolved.UnavailableReason) + }) + } +} + +func TestListConfiguredModels_PolicyAwareAvailability(t *testing.T) { + t.Parallel() + + configuredProvider := func(provider string, apiKey string) chatprovider.ConfiguredProvider { + return chatprovider.ConfiguredProvider{ + ProviderID: uuid.New(), + Provider: provider, + APIKey: apiKey, + } + } + enabledProviders := func(providers ...string) map[string]struct{} { + result := make(map[string]struct{}, len(providers)) + for _, provider := range providers { + result[chatprovider.NormalizeProvider(provider)] = struct{}{} + } + return result + } + + catalog := chatprovider.NewModelCatalog() + tests := []struct { + name string + configuredProviders []chatprovider.ConfiguredProvider + configuredModels []chatprovider.ConfiguredModel + availabilityByProvider map[string]chatprovider.ProviderAvailability + enabledProviders map[string]struct{} + want codersdk.ChatModelsResponse + }{ + { + name: "PolicyUnavailableOverridesConfiguredKey", + configuredProviders: []chatprovider.ConfiguredProvider{ + configuredProvider(fantasyopenai.Name, "sk-central"), + }, + configuredModels: []chatprovider.ConfiguredModel{{ + Provider: fantasyopenai.Name, + Model: "gpt-4", + }}, + availabilityByProvider: map[string]chatprovider.ProviderAvailability{ + fantasyopenai.Name: { + Available: false, + UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, + }, + }, + enabledProviders: enabledProviders(fantasyopenai.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{ + Provider: fantasyopenai.Name, + Available: false, + UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, + Models: []codersdk.ChatModel{{ + ID: fantasyopenai.Name + ":gpt-4", + Provider: fantasyopenai.Name, + Model: "gpt-4", + DisplayName: "gpt-4", + }}, + }}}, + }, + { + name: "PolicyAvailableMarksProviderAvailable", + configuredProviders: []chatprovider.ConfiguredProvider{ + configuredProvider(fantasyanthropic.Name, "sk-central"), + }, + configuredModels: []chatprovider.ConfiguredModel{{ + Provider: fantasyanthropic.Name, + Model: "claude-3-5-sonnet", + }}, + availabilityByProvider: map[string]chatprovider.ProviderAvailability{ + fantasyanthropic.Name: {Available: true}, + }, + enabledProviders: enabledProviders(fantasyanthropic.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{ + Provider: fantasyanthropic.Name, + Available: true, + Models: []codersdk.ChatModel{{ + ID: fantasyanthropic.Name + ":claude-3-5-sonnet", + Provider: fantasyanthropic.Name, + Model: "claude-3-5-sonnet", + DisplayName: "claude-3-5-sonnet", + }}, + }}}, + }, + { + name: "DisabledProviderOmitted", + configuredProviders: []chatprovider.ConfiguredProvider{ + configuredProvider(fantasyanthropic.Name, "sk-anthropic"), + configuredProvider(fantasyopenai.Name, "sk-openai"), + }, + configuredModels: []chatprovider.ConfiguredModel{ + {Provider: fantasyanthropic.Name, Model: "claude-3-5-sonnet"}, + {Provider: fantasyopenai.Name, Model: "gpt-4"}, + }, + availabilityByProvider: map[string]chatprovider.ProviderAvailability{ + fantasyanthropic.Name: {Available: true}, + fantasyopenai.Name: {Available: true}, + }, + enabledProviders: enabledProviders(fantasyopenai.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{ + Provider: fantasyopenai.Name, + Available: true, + Models: []codersdk.ChatModel{{ + ID: fantasyopenai.Name + ":gpt-4", + Provider: fantasyopenai.Name, + Model: "gpt-4", + DisplayName: "gpt-4", + }}, + }}}, + }, + { + name: "MissingAvailabilityDefaultsToMissingAPIKey", + configuredProviders: []chatprovider.ConfiguredProvider{ + configuredProvider(fantasyopenai.Name, "sk-central"), + }, + configuredModels: []chatprovider.ConfiguredModel{{ + Provider: fantasyopenai.Name, + Model: "gpt-4o", + }}, + enabledProviders: enabledProviders(fantasyopenai.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{ + Provider: fantasyopenai.Name, + Available: false, + UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey, + Models: []codersdk.ChatModel{{ + ID: fantasyopenai.Name + ":gpt-4o", + Provider: fantasyopenai.Name, + Model: "gpt-4o", + DisplayName: "gpt-4o", + }}, + }}}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, ok := catalog.ListConfiguredModels( + tt.configuredProviders, + tt.configuredModels, + tt.availabilityByProvider, + tt.enabledProviders, + ) + require.True(t, ok) + require.Equal(t, tt.want, got) + }) + } +} + +func TestListConfiguredProviderAvailability_PolicyAwareFiltering(t *testing.T) { + t.Parallel() + + enabledProviders := func(providers ...string) map[string]struct{} { + result := make(map[string]struct{}, len(providers)) + for _, provider := range providers { + result[chatprovider.NormalizeProvider(provider)] = struct{}{} + } + return result + } + + catalog := chatprovider.NewModelCatalog() + tests := []struct { + name string + availabilityByProvider map[string]chatprovider.ProviderAvailability + enabledProviders map[string]struct{} + want codersdk.ChatModelsResponse + }{ + { + name: "EnabledProvidersUsePolicyAvailability", + availabilityByProvider: map[string]chatprovider.ProviderAvailability{ + fantasyanthropic.Name: { + Available: false, + UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, + }, + fantasyopenai.Name: {Available: true}, + }, + enabledProviders: enabledProviders(fantasyanthropic.Name, fantasyopenai.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{ + { + Provider: fantasyanthropic.Name, + Available: false, + UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, + Models: []codersdk.ChatModel{}, + }, + { + Provider: fantasyopenai.Name, + Available: true, + Models: []codersdk.ChatModel{}, + }, + }}, + }, + { + name: "DisabledSupportedProviderOmitted", + availabilityByProvider: map[string]chatprovider.ProviderAvailability{ + fantasyanthropic.Name: {Available: true}, + fantasyopenai.Name: {Available: true}, + }, + enabledProviders: enabledProviders(fantasyopenai.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{ + Provider: fantasyopenai.Name, + Available: true, + Models: []codersdk.ChatModel{}, + }}}, + }, + { + name: "MissingAvailabilityDefaultsToMissingAPIKey", + enabledProviders: enabledProviders(fantasyopenai.Name), + want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{ + Provider: fantasyopenai.Name, + Available: false, + UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey, + Models: []codersdk.ChatModel{}, + }}}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := catalog.ListConfiguredProviderAvailability( + tt.availabilityByProvider, + tt.enabledProviders, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestPruneDisabledProviderKeys(t *testing.T) { + t.Parallel() + + enabledProviders := func(providers ...string) map[string]struct{} { + result := make(map[string]struct{}, len(providers)) + for _, provider := range providers { + result[chatprovider.NormalizeProvider(provider)] = struct{}{} + } + return result + } + + tests := []struct { + name string + keys chatprovider.ProviderAPIKeys + enabledProviders map[string]struct{} + want chatprovider.ProviderAPIKeys + }{ + { + name: "DisabledProviderEntriesRemoved", + keys: chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyanthropic.Name: "sk-anthropic", + fantasyopenai.Name: "sk-openai", + }, + BaseURLByProvider: map[string]string{ + fantasyanthropic.Name: "https://anthropic.example.com", + fantasyopenai.Name: "https://openai.example.com", + }, + }, + enabledProviders: enabledProviders(fantasyopenai.Name), + want: chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyopenai.Name: "sk-openai", + }, + BaseURLByProvider: map[string]string{ + fantasyopenai.Name: "https://openai.example.com", + }, + }, + }, + { + name: "OpenAIDisabledClearsLegacyField", + keys: chatprovider.ProviderAPIKeys{ + OpenAI: "sk-openai", + Anthropic: "sk-anthropic", + ByProvider: map[string]string{ + fantasyopenai.Name: "sk-openai", + fantasyanthropic.Name: "sk-anthropic", + }, + BaseURLByProvider: map[string]string{ + fantasyopenai.Name: "https://openai.example.com", + fantasyanthropic.Name: "https://anthropic.example.com", + }, + }, + enabledProviders: enabledProviders(fantasyanthropic.Name), + want: chatprovider.ProviderAPIKeys{ + Anthropic: "sk-anthropic", + ByProvider: map[string]string{ + fantasyanthropic.Name: "sk-anthropic", + }, + BaseURLByProvider: map[string]string{ + fantasyanthropic.Name: "https://anthropic.example.com", + }, + }, + }, + { + name: "AnthropicDisabledClearsLegacyField", + keys: chatprovider.ProviderAPIKeys{ + OpenAI: "sk-openai", + Anthropic: "sk-anthropic", + ByProvider: map[string]string{ + fantasyopenai.Name: "sk-openai", + fantasyanthropic.Name: "sk-anthropic", + }, + BaseURLByProvider: map[string]string{ + fantasyopenai.Name: "https://openai.example.com", + fantasyanthropic.Name: "https://anthropic.example.com", + }, + }, + enabledProviders: enabledProviders(fantasyopenai.Name), + want: chatprovider.ProviderAPIKeys{ + OpenAI: "sk-openai", + ByProvider: map[string]string{ + fantasyopenai.Name: "sk-openai", + }, + BaseURLByProvider: map[string]string{ + fantasyopenai.Name: "https://openai.example.com", + }, + }, + }, + { + name: "AllEnabledLeavesKeysUnchanged", + keys: chatprovider.ProviderAPIKeys{ + OpenAI: "sk-openai", + Anthropic: "sk-anthropic", + ByProvider: map[string]string{ + fantasyopenai.Name: "sk-openai", + fantasyanthropic.Name: "sk-anthropic", + }, + BaseURLByProvider: map[string]string{ + fantasyopenai.Name: "https://openai.example.com", + fantasyanthropic.Name: "https://anthropic.example.com", + }, + }, + enabledProviders: enabledProviders(fantasyopenai.Name, fantasyanthropic.Name), + want: chatprovider.ProviderAPIKeys{ + OpenAI: "sk-openai", + Anthropic: "sk-anthropic", + ByProvider: map[string]string{ + fantasyopenai.Name: "sk-openai", + fantasyanthropic.Name: "sk-anthropic", + }, + BaseURLByProvider: map[string]string{ + fantasyopenai.Name: "https://openai.example.com", + fantasyanthropic.Name: "https://anthropic.example.com", + }, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + keys := tt.keys + chatprovider.PruneDisabledProviderKeys(&keys, tt.enabledProviders) + require.Equal(t, tt.want, keys) + }) + } +} + func TestCoderHeaders(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chattool/startworkspace_test.go b/coderd/x/chatd/chattool/startworkspace_test.go index 10c230837e..119e2deafe 100644 --- a/coderd/x/chatd/chattool/startworkspace_test.go +++ b/coderd/x/chatd/chattool/startworkspace_test.go @@ -440,13 +440,14 @@ func seedModelConfig( t.Helper() _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - BaseUrl: "", - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + BaseUrl: "", + ApiKeyKeyID: sql.NullString{}, + CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 7c0645457d..cbb01f9f73 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -122,13 +122,14 @@ func seedInternalChatDeps( user := dbgen.User(t, db, database.User{}) _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - BaseUrl: "", - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + BaseUrl: "", + ApiKeyKeyID: sql.NullString{}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) diff --git a/coderd/x/gitsync/worker_test.go b/coderd/x/gitsync/worker_test.go index 2b406336f6..ce89a3f77f 100644 --- a/coderd/x/gitsync/worker_test.go +++ b/coderd/x/gitsync/worker_test.go @@ -959,9 +959,10 @@ func TestWorker(t *testing.T) { // 3. Set up FK chain: chat_providers -> chat_model_configs -> chats. _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + Enabled: true, + CentralApiKeyEnabled: true, }) require.NoError(t, err) diff --git a/codersdk/chats.go b/codersdk/chats.go index 15ad632dd6..41f1ea3d1f 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -405,6 +405,8 @@ type ChatModelProviderUnavailableReason string const ( ChatModelProviderUnavailableMissingAPIKey ChatModelProviderUnavailableReason = "missing_api_key" ChatModelProviderUnavailableFetchFailed ChatModelProviderUnavailableReason = "fetch_failed" + // #nosec G101 + ChatModelProviderUnavailableReasonUserAPIKeyRequired ChatModelProviderUnavailableReason = "user_api_key_required" ) // ChatModel represents a model in the chat model catalog. @@ -532,32 +534,57 @@ const ( // ChatProviderConfig is an admin-managed provider configuration. type ChatProviderConfig struct { - ID uuid.UUID `json:"id" format:"uuid"` - Provider string `json:"provider"` - DisplayName string `json:"display_name"` - Enabled bool `json:"enabled"` - HasAPIKey bool `json:"has_api_key"` - BaseURL string `json:"base_url,omitempty"` - Source ChatProviderConfigSource `json:"source"` - CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"` - UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"` + ID uuid.UUID `json:"id" format:"uuid"` + Provider string `json:"provider"` + DisplayName string `json:"display_name"` + Enabled bool `json:"enabled"` + HasAPIKey bool `json:"has_api_key"` + CentralAPIKeyEnabled bool `json:"central_api_key_enabled"` + AllowUserAPIKey bool `json:"allow_user_api_key"` + AllowCentralAPIKeyFallback bool `json:"allow_central_api_key_fallback"` + BaseURL string `json:"base_url,omitempty"` + Source ChatProviderConfigSource `json:"source"` + CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"` + UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"` } // CreateChatProviderConfigRequest creates a chat provider config. type CreateChatProviderConfigRequest struct { - Provider string `json:"provider"` - DisplayName string `json:"display_name,omitempty"` - APIKey string `json:"api_key,omitempty"` - BaseURL string `json:"base_url,omitempty"` - Enabled *bool `json:"enabled,omitempty"` + Provider string `json:"provider"` + DisplayName string `json:"display_name,omitempty"` + APIKey string `json:"api_key,omitempty"` + BaseURL string `json:"base_url,omitempty"` + Enabled *bool `json:"enabled,omitempty"` + CentralAPIKeyEnabled *bool `json:"central_api_key_enabled,omitempty"` + AllowUserAPIKey *bool `json:"allow_user_api_key,omitempty"` + AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"` } // UpdateChatProviderConfigRequest updates a chat provider config. type UpdateChatProviderConfigRequest struct { - DisplayName string `json:"display_name,omitempty"` - APIKey *string `json:"api_key,omitempty"` - BaseURL *string `json:"base_url,omitempty"` - Enabled *bool `json:"enabled,omitempty"` + DisplayName string `json:"display_name,omitempty"` + APIKey *string `json:"api_key,omitempty"` + BaseURL *string `json:"base_url,omitempty"` + Enabled *bool `json:"enabled,omitempty"` + CentralAPIKeyEnabled *bool `json:"central_api_key_enabled,omitempty"` + AllowUserAPIKey *bool `json:"allow_user_api_key,omitempty"` + AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"` +} + +// UserChatProviderConfig is a summary of a provider that allows +// user-supplied keys, as seen from the current user's perspective. +type UserChatProviderConfig struct { + ProviderID uuid.UUID `json:"provider_id" format:"uuid"` + Provider string `json:"provider"` + DisplayName string `json:"display_name"` + HasUserAPIKey bool `json:"has_user_api_key"` + HasCentralAPIKeyFallback bool `json:"has_central_api_key_fallback"` +} + +// CreateUserChatProviderKeyRequest creates or replaces a user's API key +// for a provider. +type CreateUserChatProviderKeyRequest struct { + APIKey string `json:"api_key"` } // ChatModelConfig is an admin-managed model configuration. @@ -1332,6 +1359,47 @@ func (c *ExperimentalClient) DeleteChatProvider(ctx context.Context, providerID return nil } +// ListUserChatProviderConfigs returns user-scoped chat provider configs. +func (c *ExperimentalClient) ListUserChatProviderConfigs(ctx context.Context) ([]UserChatProviderConfig, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-provider-configs", nil) + if err != nil { + return nil, xerrors.Errorf("list user chat provider configs: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + var configs []UserChatProviderConfig + return configs, json.NewDecoder(res.Body).Decode(&configs) +} + +// UpsertUserChatProviderKey creates or replaces a user API key for a provider. +func (c *ExperimentalClient) UpsertUserChatProviderKey(ctx context.Context, providerID uuid.UUID, req CreateUserChatProviderKeyRequest) (UserChatProviderConfig, error) { + res, err := c.Request(ctx, http.MethodPut, fmt.Sprintf("/api/experimental/chats/user-provider-configs/%s", providerID), req) + if err != nil { + return UserChatProviderConfig{}, xerrors.Errorf("upsert user chat provider key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return UserChatProviderConfig{}, ReadBodyAsError(res) + } + var config UserChatProviderConfig + return config, json.NewDecoder(res.Body).Decode(&config) +} + +// DeleteUserChatProviderKey deletes a user API key for a provider. +func (c *ExperimentalClient) DeleteUserChatProviderKey(ctx context.Context, providerID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/user-provider-configs/%s", providerID), nil) + if err != nil { + return xerrors.Errorf("delete user chat provider key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // ListChatModelConfigs returns admin-managed chat model configs. func (c *ExperimentalClient) ListChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) { res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/model-configs", nil) diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index 7a0512da2d..d1861fd561 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -104,13 +104,14 @@ func seedChatDependencies( user := dbgen.User(t, db, database.User{}) _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - BaseUrl: safetyNet.URL, - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - Enabled: true, + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + BaseUrl: safetyNet.URL, + CentralApiKeyEnabled: true, + ApiKeyKeyID: sql.NullString{}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, }) require.NoError(t, err) model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ @@ -186,12 +187,15 @@ func setOpenAIProviderBaseURL( require.NoError(t, err) _, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: baseURL, - ApiKeyKeyID: provider.ApiKeyKeyID, - Enabled: provider.Enabled, + ID: provider.ID, + DisplayName: provider.DisplayName, + APIKey: provider.APIKey, + BaseUrl: baseURL, + CentralApiKeyEnabled: true, + AllowUserApiKey: false, + AllowCentralApiKeyFallback: false, + ApiKeyKeyID: provider.ApiKeyKeyID, + Enabled: provider.Enabled, }) require.NoError(t, err) } diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index c435bb1b6c..803c3d643d 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -73,12 +73,35 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err) } } + + userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid) + if err != nil { + return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err) + } + for _, userProviderKey := range userProviderKeys { + if strings.TrimSpace(userProviderKey.APIKey) == "" { + continue + } + if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() { + log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + continue + } + if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{ + UserID: userProviderKey.UserID, + ChatProviderID: userProviderKey.ChatProviderID, + APIKey: userProviderKey.APIKey, + ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required + }); err != nil { + return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err) + } + log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + } return nil }, &database.TxOptions{ Isolation: sql.LevelRepeatableRead, }) if err != nil { - return xerrors.Errorf("update user links: %w", err) + return xerrors.Errorf("update user tokens and chat provider keys: %w", err) } log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } @@ -97,12 +120,15 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe continue } if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: provider.BaseUrl, - ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required - Enabled: provider.Enabled, - ID: provider.ID, + DisplayName: provider.DisplayName, + APIKey: provider.APIKey, + BaseUrl: provider.BaseUrl, + ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required + Enabled: provider.Enabled, + CentralApiKeyEnabled: provider.CentralApiKeyEnabled, + AllowUserApiKey: provider.AllowUserApiKey, + AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, + ID: provider.ID, }); err != nil { return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err) } @@ -189,12 +215,32 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err) } } + + userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid) + if err != nil { + return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err) + } + for _, userProviderKey := range userProviderKeys { + if !userProviderKey.ApiKeyKeyID.Valid { + log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1)) + continue + } + if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{ + UserID: userProviderKey.UserID, + ChatProviderID: userProviderKey.ChatProviderID, + APIKey: userProviderKey.APIKey, + ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id + }); err != nil { + return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err) + } + log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1)) + } return nil }, &database.TxOptions{ Isolation: sql.LevelRepeatableRead, }) if err != nil { - return xerrors.Errorf("update user links: %w", err) + return xerrors.Errorf("update user tokens and chat provider keys: %w", err) } log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } @@ -209,12 +255,15 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph continue } if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: provider.BaseUrl, - ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id - Enabled: provider.Enabled, - ID: provider.ID, + DisplayName: provider.DisplayName, + APIKey: provider.APIKey, + BaseUrl: provider.BaseUrl, + ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id + Enabled: provider.Enabled, + CentralApiKeyEnabled: provider.CentralApiKeyEnabled, + AllowUserApiKey: provider.AllowUserApiKey, + AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, + ID: provider.ID, }); err != nil { return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err) } @@ -241,6 +290,8 @@ DELETE FROM user_links DELETE FROM external_auth_links WHERE oauth_access_token_key_id IS NOT NULL OR oauth_refresh_token_key_id IS NOT NULL; +DELETE FROM user_chat_provider_keys + WHERE api_key_key_id IS NOT NULL; UPDATE chat_providers SET api_key = '', api_key_key_id = NULL diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 3c5f957786..6779b02b8f 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -471,6 +471,57 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat return provider, nil } +func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error { + return db.decryptField(&key.APIKey, key.ApiKeyKeyID) +} + +func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { + keys, err := db.Store.GetUserChatProviderKeys(ctx, userID) + if err != nil { + return nil, err + } + for i := range keys { + if err := db.decryptUserChatProviderKey(&keys[i]); err != nil { + return nil, err + } + } + return keys, nil +} + +func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + if strings.TrimSpace(params.APIKey) == "" { + params.ApiKeyKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { + return database.UserChatProviderKey{}, err + } + + key, err := db.Store.UpsertUserChatProviderKey(ctx, params) + if err != nil { + return database.UserChatProviderKey{}, err + } + if err := db.decryptUserChatProviderKey(&key); err != nil { + return database.UserChatProviderKey{}, err + } + return key, nil +} + +func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { + if strings.TrimSpace(params.APIKey) == "" { + params.ApiKeyKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { + return database.UserChatProviderKey{}, err + } + + key, err := db.Store.UpdateUserChatProviderKey(ctx, params) + if err != nil { + return database.UserChatProviderKey{}, err + } + if err := db.decryptUserChatProviderKey(&key); err != nil { + return database.UserChatProviderKey{}, err + } + return key, nil +} + // decryptMCPServerConfig decrypts all encrypted fields on a // single MCPServerConfig in place. func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index d664987a56..eb82fad168 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1177,3 +1177,113 @@ func TestMCPServerUserTokens(t *testing.T) { requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken) }) } + +func TestUserChatProviderKeys(t *testing.T) { + t.Parallel() + ctx := context.Background() + + const ( + //nolint:gosec // test credentials + initialAPIKey = "sk-initial-api-key-value" + //nolint:gosec // test credentials + updatedAPIKey = "sk-updated-api-key-value" + ) + + insertProviderAndKey := func( + t *testing.T, + crypt *dbCrypt, + ciphers []Cipher, + ) (database.ChatProvider, database.UserChatProviderKey) { + t.Helper() + user := dbgen.User(t, crypt, database.User{}) + provider, err := crypt.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "", + Enabled: true, + AllowUserApiKey: true, + }) + require.NoError(t, err) + + key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ + UserID: user.ID, + ChatProviderID: provider.ID, + APIKey: initialAPIKey, + }) + require.NoError(t, err) + require.Equal(t, initialAPIKey, key.APIKey) + require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String) + return provider, key + } + + getUserChatProviderKey := func(t *testing.T, store interface { + GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error) + }, userID uuid.UUID, providerID uuid.UUID, + ) database.UserChatProviderKey { + t.Helper() + keys, err := store.GetUserChatProviderKeys(ctx, userID) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, providerID, keys[0].ChatProviderID) + return keys[0] + } + + t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID) + require.Equal(t, key.ID, got.ID) + require.Equal(t, initialAPIKey, got.APIKey) + require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) + + rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID) + require.NotEqual(t, initialAPIKey, rawKey.APIKey) + requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey) + }) + + t.Run("GetUserChatProviderKeys", func(t *testing.T) { + t.Parallel() + _, crypt, ciphers := setup(t) + _, key := insertProviderAndKey(t, crypt, ciphers) + + keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, key.ID, keys[0].ID) + require.Equal(t, initialAPIKey, keys[0].APIKey) + require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String) + }) + + t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ + UserID: key.UserID, + ChatProviderID: provider.ID, + APIKey: updatedAPIKey, + }) + require.NoError(t, err) + require.Equal(t, key.ID, updated.ID) + require.Equal(t, key.CreatedAt, updated.CreatedAt) + require.False(t, updated.UpdatedAt.Before(key.UpdatedAt)) + require.Equal(t, updatedAPIKey, updated.APIKey) + require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String) + + got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID) + require.Equal(t, updatedAPIKey, got.APIKey) + require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) + + keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, updatedAPIKey, keys[0].APIKey) + + rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID) + require.NotEqual(t, updatedAPIKey, rawKey.APIKey) + requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey) + }) +} diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 1d1d14afe7..cc8c56d182 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1776,10 +1776,11 @@ export interface ChatModelProviderOptions { // From codersdk/chats.go export type ChatModelProviderUnavailableReason = | "fetch_failed" - | "missing_api_key"; + | "missing_api_key" + | "user_api_key_required"; export const ChatModelProviderUnavailableReasons: ChatModelProviderUnavailableReason[] = - ["fetch_failed", "missing_api_key"]; + ["fetch_failed", "missing_api_key", "user_api_key_required"]; // From codersdk/chats.go /** @@ -1836,6 +1837,9 @@ export interface ChatProviderConfig { readonly display_name: string; readonly enabled: boolean; readonly has_api_key: boolean; + readonly central_api_key_enabled: boolean; + readonly allow_user_api_key: boolean; + readonly allow_central_api_key_fallback: boolean; readonly base_url?: string; readonly source: ChatProviderConfigSource; readonly created_at?: string; @@ -2362,6 +2366,9 @@ export interface CreateChatProviderConfigRequest { readonly api_key?: string; readonly base_url?: string; readonly enabled?: boolean; + readonly central_api_key_enabled?: boolean; + readonly allow_user_api_key?: boolean; + readonly allow_central_api_key_fallback?: boolean; } // From codersdk/chats.go @@ -2643,6 +2650,15 @@ export interface CreateTokenRequest { readonly allow_list?: readonly APIAllowListTarget[]; } +// From codersdk/chats.go +/** + * CreateUserChatProviderKeyRequest creates or replaces a user's API key + * for a provider. + */ +export interface CreateUserChatProviderKeyRequest { + readonly api_key: string; +} + // From codersdk/users.go export interface CreateUserRequestWithOrgs { readonly email: string; @@ -7150,6 +7166,9 @@ export interface UpdateChatProviderConfigRequest { readonly api_key?: string; readonly base_url?: string; readonly enabled?: boolean; + readonly central_api_key_enabled?: boolean; + readonly allow_user_api_key?: boolean; + readonly allow_central_api_key_fallback?: boolean; } // From codersdk/chats.go @@ -7694,6 +7713,19 @@ export interface UserChatCustomPrompt { readonly custom_prompt: string; } +// From codersdk/chats.go +/** + * UserChatProviderConfig is a summary of a provider that allows + * user-supplied keys, as seen from the current user's perspective. + */ +export interface UserChatProviderConfig { + readonly provider_id: string; + readonly provider: string; + readonly display_name: string; + readonly has_user_api_key: boolean; + readonly has_central_api_key_fallback: boolean; +} + // From codersdk/insights.go /** * UserLatency shows the connection latency for a user. diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx index dec9fda203..78845f668d 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx @@ -20,6 +20,10 @@ const createProviderConfig = ( display_name: overrides.display_name ?? "", enabled: overrides.enabled ?? true, has_api_key: overrides.has_api_key ?? false, + central_api_key_enabled: overrides.central_api_key_enabled ?? true, + allow_user_api_key: overrides.allow_user_api_key ?? false, + allow_central_api_key_fallback: + overrides.allow_central_api_key_fallback ?? false, base_url: overrides.base_url ?? "", source: overrides.source ?? "database", created_at: overrides.created_at ?? now, diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.stories.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.stories.tsx index 4515b269b9..8d6451282b 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.stories.tsx @@ -14,6 +14,9 @@ const providerState: ProviderState = { display_name: "OpenAI", enabled: true, has_api_key: true, + central_api_key_enabled: true, + allow_user_api_key: false, + allow_central_api_key_fallback: false, base_url: undefined, source: "database", created_at: "2025-01-01T00:00:00Z",