feat: provider key policies and user provider settings (#23751)

This commit is contained in:
Michael Suchacz
2026-04-02 19:46:42 +02:00
committed by GitHub
parent 17dec2a70f
commit 7d0a0c6495
39 changed files with 3551 additions and 266 deletions
+3
View File
@@ -110,6 +110,9 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
- For experimental or unstable API paths, skip public doc generation with
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
keeps them out of the published API reference until they stabilize.
- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger
annotations entirely. Do not add `@Summary`, `@Router`, or other
swagger comments to handlers in that file.
### Database Query Naming
+8 -1
View File
@@ -782,7 +782,7 @@ func New(options *Options) *API {
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
CreateWorkspace: api.chatCreateWorkspace,
@@ -1221,6 +1221,13 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/user-provider-configs", func(r chi.Router) {
r.Get("/", api.listUserChatProviderConfigs)
r.Route("/{providerConfig}", func(r chi.Router) {
r.Put("/", api.upsertUserChatProviderKey)
r.Delete("/", api.deleteUserChatProviderKey)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
+2
View File
@@ -10,6 +10,7 @@ const (
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
@@ -32,4 +33,5 @@ const (
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
)
+44
View File
@@ -2137,6 +2137,17 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
}
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
@@ -4024,6 +4035,17 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return nil, err
}
return q.db.GetUserChatProviderKeys(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
@@ -6454,6 +6476,17 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
}
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpdateUserChatProviderKey(ctx, arg)
}
func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id)
}
@@ -7181,6 +7214,17 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
return q.db.UpsertTemplateUsageStats(ctx)
}
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return database.UserChatProviderKey{}, err
}
return q.db.UpsertUserChatProviderKey(ctx, arg)
}
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
+30
View File
@@ -2407,6 +2407,36 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
}))
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
}))
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
}))
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
+32
View File
@@ -696,6 +696,14 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
return r0
}
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
start := time.Now()
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
return r0
}
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteUserSecret(ctx, id)
@@ -2528,6 +2536,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
@@ -4560,6 +4576,14 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.UpdateUserDeletedByID(ctx, id)
@@ -5152,6 +5176,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
start := time.Now()
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
+59
View File
@@ -1171,6 +1171,20 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
}
// DeleteUserChatProviderKey mocks base method.
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
}
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -4729,6 +4743,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatProviderKeys mocks base method.
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
ret0, _ := ret[0].([]database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
}
// GetUserChatSpendInPeriod mocks base method.
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
m.ctrl.T.Helper()
@@ -8605,6 +8634,21 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
}
// UpdateUserChatProviderKey mocks base method.
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
}
// UpdateUserDeletedByID mocks base method.
func (m *MockStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -9671,6 +9715,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
}
// UpsertUserChatProviderKey mocks base method.
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
ret0, _ := ret[0].(database.UserChatProviderKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
}
// UpsertWebpushVAPIDKeys mocks base method.
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
m.ctrl.T.Helper()
+31 -1
View File
@@ -1341,7 +1341,11 @@ CREATE TABLE chat_providers (
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
base_url text DEFAULT ''::text NOT NULL,
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
central_api_key_enabled boolean DEFAULT true NOT NULL,
allow_user_api_key boolean DEFAULT false NOT NULL,
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
);
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
@@ -2752,6 +2756,17 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of
COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.';
CREATE TABLE user_chat_provider_keys (
id uuid DEFAULT gen_random_uuid() NOT NULL,
user_id uuid NOT NULL,
chat_provider_id uuid NOT NULL,
api_key text NOT NULL,
api_key_key_id text,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
);
CREATE TABLE user_configs (
user_id uuid NOT NULL,
key character varying(256) NOT NULL,
@@ -3548,6 +3563,12 @@ ALTER TABLE ONLY usage_events_daily
ALTER TABLE ONLY usage_events
ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
ALTER TABLE ONLY user_configs
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
@@ -4258,6 +4279,15 @@ ALTER TABLE ONLY templates
ALTER TABLE ONLY templates
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_chat_provider_keys
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_configs
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
@@ -92,6 +92,9 @@ const (
ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE;
ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT;
ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -0,0 +1,8 @@
DROP TABLE IF EXISTS user_chat_provider_keys;
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
ALTER TABLE chat_providers
DROP COLUMN IF EXISTS central_api_key_enabled,
DROP COLUMN IF EXISTS allow_user_api_key,
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
@@ -0,0 +1,24 @@
ALTER TABLE chat_providers
ADD COLUMN central_api_key_enabled BOOLEAN NOT NULL DEFAULT TRUE,
ADD COLUMN allow_user_api_key BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN allow_central_api_key_fallback BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE chat_providers
ADD CONSTRAINT valid_credential_policy CHECK (
(central_api_key_enabled OR allow_user_api_key) AND
(
NOT allow_central_api_key_fallback OR
(central_api_key_enabled AND allow_user_api_key)
)
);
CREATE TABLE user_chat_provider_keys (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
chat_provider_id UUID NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE,
api_key TEXT NOT NULL CHECK (api_key != ''),
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (user_id, chat_provider_id)
);
@@ -0,0 +1,16 @@
INSERT INTO user_chat_provider_keys (
user_id,
chat_provider_id,
api_key,
created_at,
updated_at
)
SELECT
id,
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
'fixture-test-key',
'2025-01-01 00:00:00+00',
'2025-01-01 00:00:00+00'
FROM users
ORDER BY created_at, id
LIMIT 1;
+19 -6
View File
@@ -4264,12 +4264,15 @@ type ChatProvider struct {
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
BaseUrl string `db:"base_url" json:"base_url"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
}
type ChatQueuedMessage struct {
@@ -5222,6 +5225,16 @@ type User struct {
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
}
type UserChatProviderKey struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
type UserConfig struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Key string `db:"key" json:"key"`
+4
View File
@@ -150,6 +150,7 @@ type sqlcQuerier interface {
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
@@ -577,6 +578,7 @@ type sqlcQuerier interface {
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
// Returns the minimum (most restrictive) group limit for a user.
@@ -927,6 +929,7 @@ type sqlcQuerier interface {
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
UpdateUserHashedOneTimePasscode(ctx context.Context, arg UpdateUserHashedOneTimePasscodeParams) error
@@ -1015,6 +1018,7 @@ type sqlcQuerier interface {
// used to store the data, and the minutes are summed for each user and template
// combination. The result is stored in the template_usage_stats table.
UpsertTemplateUsageStats(ctx context.Context) error
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
+35 -28
View File
@@ -1261,10 +1261,11 @@ func TestGetAuthorizedChats(t *testing.T) {
// Create FK dependencies: a chat provider and model config.
ctx := testutil.Context(t, testutil.WaitMedium)
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -9456,10 +9457,11 @@ func TestInsertChatMessages(t *testing.T) {
provider := "openai"
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
Provider: provider,
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -9621,10 +9623,11 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
// A chat_providers row is required as a FK for model configs.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -9992,10 +9995,11 @@ func TestGetPRInsights(t *testing.T) {
user := dbgen.User(t, store, database.User{})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-key",
Enabled: true,
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -10516,10 +10520,11 @@ func TestChatPinOrderQueries(t *testing.T) {
// timed test context doesn't tick during DB init.
bg := context.Background()
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -10696,10 +10701,11 @@ func TestChatLabels(t *testing.T) {
owner := dbgen.User(t, db, database.User{})
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -10907,10 +10913,11 @@ func TestChatHasUnread(t *testing.T) {
user := dbgen.User(t, store, database.User{})
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
+181 -22
View File
@@ -3799,7 +3799,7 @@ func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) e
const getChatProviderByID = `-- name: GetChatProviderByID :one
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
@@ -3820,13 +3820,16 @@ func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (Cha
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
@@ -3847,13 +3850,16 @@ func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider str
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
const getChatProviders = `-- name: GetChatProviders :many
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
ORDER BY
@@ -3880,6 +3886,9 @@ func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, erro
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
); err != nil {
return nil, err
}
@@ -3896,7 +3905,7 @@ func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, erro
const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many
SELECT
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
FROM
chat_providers
WHERE
@@ -3925,6 +3934,9 @@ func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvide
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
); err != nil {
return nil, err
}
@@ -3947,7 +3959,10 @@ INSERT INTO chat_providers (
base_url,
api_key_key_id,
created_by,
enabled
enabled,
central_api_key_enabled,
allow_user_api_key,
allow_central_api_key_fallback
) VALUES (
$1::text,
$2::text,
@@ -3955,20 +3970,26 @@ INSERT INTO chat_providers (
$4::text,
$5::text,
$6::uuid,
$7::boolean
$7::boolean,
$8::boolean,
$9::boolean,
$10::boolean
)
RETURNING
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
`
type InsertChatProviderParams struct {
Provider string `db:"provider" json:"provider"`
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
Provider string `db:"provider" json:"provider"`
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
Enabled bool `db:"enabled" json:"enabled"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
}
func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) {
@@ -3980,6 +4001,9 @@ func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProvi
arg.ApiKeyKeyID,
arg.CreatedBy,
arg.Enabled,
arg.CentralApiKeyEnabled,
arg.AllowUserApiKey,
arg.AllowCentralApiKeyFallback,
)
var i ChatProvider
err := row.Scan(
@@ -3993,6 +4017,9 @@ func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProvi
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
@@ -4006,20 +4033,26 @@ SET
base_url = $3::text,
api_key_key_id = $4::text,
enabled = $5::boolean,
central_api_key_enabled = $6::boolean,
allow_user_api_key = $7::boolean,
allow_central_api_key_fallback = $8::boolean,
updated_at = NOW()
WHERE
id = $6::uuid
id = $9::uuid
RETURNING
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
`
type UpdateChatProviderParams struct {
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
Enabled bool `db:"enabled" json:"enabled"`
ID uuid.UUID `db:"id" json:"id"`
DisplayName string `db:"display_name" json:"display_name"`
APIKey string `db:"api_key" json:"api_key"`
BaseUrl string `db:"base_url" json:"base_url"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
Enabled bool `db:"enabled" json:"enabled"`
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) {
@@ -4029,6 +4062,9 @@ func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProvi
arg.BaseUrl,
arg.ApiKeyKeyID,
arg.Enabled,
arg.CentralApiKeyEnabled,
arg.AllowUserApiKey,
arg.AllowCentralApiKeyFallback,
arg.ID,
)
var i ChatProvider
@@ -4043,6 +4079,9 @@ func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProvi
&i.CreatedAt,
&i.UpdatedAt,
&i.BaseUrl,
&i.CentralApiKeyEnabled,
&i.AllowUserApiKey,
&i.AllowCentralApiKeyFallback,
)
return i, err
}
@@ -22617,6 +22656,126 @@ func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretP
return i, err
}
const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec
DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2
`
type DeleteUserChatProviderKeyParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
}
func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error {
_, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID)
return err
}
const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many
SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC
`
func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) {
rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UserChatProviderKey
for rows.Next() {
var i UserChatProviderKey
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.ChatProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one
UPDATE user_chat_provider_keys
SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW()
WHERE user_id = $3 AND chat_provider_id = $4
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
`
type UpdateUserChatProviderKeyParams struct {
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
}
func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) {
row := q.db.QueryRowContext(ctx, updateUserChatProviderKey,
arg.APIKey,
arg.ApiKeyKeyID,
arg.UserID,
arg.ChatProviderID,
)
var i UserChatProviderKey
err := row.Scan(
&i.ID,
&i.UserID,
&i.ChatProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
VALUES ($1, $2, $3, $4::text)
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
api_key = $3,
api_key_key_id = $4::text,
updated_at = NOW()
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
`
type UpsertUserChatProviderKeyParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
APIKey string `db:"api_key" json:"api_key"`
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
}
func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) {
row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey,
arg.UserID,
arg.ChatProviderID,
arg.APIKey,
arg.ApiKeyKeyID,
)
var i UserChatProviderKey
err := row.Scan(
&i.ID,
&i.UserID,
&i.ChatProviderID,
&i.APIKey,
&i.ApiKeyKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const allUserIDs = `-- name: AllUserIDs :many
SELECT DISTINCT id FROM USERS
WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END
+11 -2
View File
@@ -40,7 +40,10 @@ INSERT INTO chat_providers (
base_url,
api_key_key_id,
created_by,
enabled
enabled,
central_api_key_enabled,
allow_user_api_key,
allow_central_api_key_fallback
) VALUES (
@provider::text,
@display_name::text,
@@ -48,7 +51,10 @@ INSERT INTO chat_providers (
@base_url::text,
sqlc.narg('api_key_key_id')::text,
sqlc.narg('created_by')::uuid,
@enabled::boolean
@enabled::boolean,
@central_api_key_enabled::boolean,
@allow_user_api_key::boolean,
@allow_central_api_key_fallback::boolean
)
RETURNING
*;
@@ -62,6 +68,9 @@ SET
base_url = @base_url::text,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
enabled = @enabled::boolean,
central_api_key_enabled = @central_api_key_enabled::boolean,
allow_user_api_key = @allow_user_api_key::boolean,
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
updated_at = NOW()
WHERE
id = @id::uuid
@@ -0,0 +1,20 @@
-- name: GetUserChatProviderKeys :many
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
-- name: UpsertUserChatProviderKey :one
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
api_key = @api_key,
api_key_key_id = sqlc.narg('api_key_key_id')::text,
updated_at = NOW()
RETURNING *;
-- name: UpdateUserChatProviderKey :one
UPDATE user_chat_provider_keys
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
RETURNING *;
-- name: DeleteUserChatProviderKey :exec
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
+2
View File
@@ -90,6 +90,8 @@ const (
UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type);
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
+447 -57
View File
@@ -520,6 +520,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
//nolint:gocritic // System context required to read enabled chat models.
systemCtx := dbauthz.AsSystemRestricted(ctx)
@@ -555,14 +556,24 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
configuredProviders := make(
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
)
enabledProviderNames := make(map[string]struct{}, len(enabledProviders))
for _, provider := range enabledProviders {
configuredProviders = append(
configuredProviders, chatprovider.ConfiguredProvider{
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
ProviderID: provider.ID,
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserAPIKey: provider.AllowUserApiKey,
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
},
)
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
if normalizedProvider == "" {
continue
}
enabledProviderNames[normalizedProvider] = struct{}{}
}
configuredModels := make(
[]chatprovider.ConfiguredModel, 0, len(enabledModels),
@@ -575,18 +586,38 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
})
}
keys := chatprovider.MergeProviderAPIKeys(
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
userKeyRows, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to load user chat provider keys.",
Detail: err.Error(),
})
return
}
userKeys := make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
for _, userKey := range userKeyRows {
userKeys = append(userKeys, chatprovider.UserProviderKey{
ChatProviderID: userKey.ChatProviderID,
APIKey: userKey.APIKey,
})
}
_, providerAvailability := chatprovider.ResolveUserProviderKeys(
ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
configuredProviders,
userKeys,
)
catalog := chatprovider.NewModelCatalog(keys)
catalog := chatprovider.NewModelCatalog()
var response codersdk.ChatModelsResponse
if configured, ok := catalog.ListConfiguredModels(
configuredProviders, configuredModels,
configuredProviders, configuredModels, providerAvailability, enabledProviderNames,
); ok {
response = configured
} else {
response = catalog.ListConfiguredProviderAvailability(configuredProviders)
response = catalog.ListConfiguredProviderAvailability(
providerAvailability,
enabledProviderNames,
)
}
httpapi.Write(ctx, rw, http.StatusOK, response)
@@ -3926,9 +3957,13 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
)
for _, provider := range enabledProviders {
normalizedProvider := normalizeChatProvider(provider.Provider)
if normalizedProvider == "" {
continue
}
enabledConfiguredProviders = append(
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
Provider: provider.Provider,
Provider: normalizedProvider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
},
@@ -3936,7 +3971,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
}
effectiveKeys := chatprovider.MergeProviderAPIKeys(
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
enabledConfiguredProviders,
)
effectiveKeys = chatprovider.MergeProviderAPIKeys(
@@ -3952,7 +3987,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
resp,
convertChatProviderConfig(
configured,
effectiveKeys.APIKey(provider) != "",
api.hasEffectiveProviderAPIKey(ctx, configured),
codersdk.ChatProviderConfigSourceDatabase,
),
)
@@ -3968,13 +4003,16 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
}
resp = append(resp, codersdk.ChatProviderConfig{
ID: uuid.Nil,
Provider: provider,
DisplayName: chatprovider.ProviderDisplayName(provider),
Enabled: enabled,
HasAPIKey: hasAPIKey,
BaseURL: effectiveKeys.BaseURL(provider),
Source: source,
ID: uuid.Nil,
Provider: provider,
DisplayName: chatprovider.ProviderDisplayName(provider),
Enabled: enabled,
HasAPIKey: hasAPIKey,
CentralAPIKeyEnabled: true,
AllowUserAPIKey: false,
AllowCentralAPIKeyFallback: false,
BaseURL: effectiveKeys.BaseURL(provider),
Source: source,
})
}
@@ -3984,6 +4022,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
var inserted database.ChatProvider
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
@@ -4003,6 +4042,14 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
return
}
if err := validateChatProviderAPIKeySize(strings.TrimSpace(req.APIKey)); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "API key too large.",
Detail: err.Error(),
})
return
}
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
@@ -4016,14 +4063,57 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
return
}
inserted, err := api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: strings.TrimSpace(req.DisplayName),
APIKey: strings.TrimSpace(req.APIKey),
BaseUrl: baseURL,
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
Enabled: enabled,
centralAPIKeyEnabled := true
if req.CentralAPIKeyEnabled != nil {
centralAPIKeyEnabled = *req.CentralAPIKeyEnabled
}
allowUserAPIKey := false
if req.AllowUserAPIKey != nil {
allowUserAPIKey = *req.AllowUserAPIKey
}
allowCentralAPIKeyFallback := false
if req.AllowCentralAPIKeyFallback != nil {
allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback
}
if err := validateChatProviderCredentialPolicy(
centralAPIKeyEnabled,
allowUserAPIKey,
allowCentralAPIKeyFallback,
); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid credential policy.",
Detail: err.Error(),
})
return
}
if err := validateChatProviderCentralAPIKey(
centralAPIKeyEnabled,
api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{
Provider: provider,
APIKey: strings.TrimSpace(req.APIKey),
BaseUrl: baseURL,
CentralApiKeyEnabled: centralAPIKeyEnabled,
}, uuid.Nil),
); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
inserted, err = api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: strings.TrimSpace(req.DisplayName),
APIKey: strings.TrimSpace(req.APIKey),
BaseUrl: baseURL,
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
Enabled: enabled,
CentralApiKeyEnabled: centralAPIKeyEnabled,
AllowUserApiKey: allowUserAPIKey,
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
})
if err != nil {
switch {
@@ -4064,6 +4154,10 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var (
existing database.ChatProvider
updated database.ChatProvider
)
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
@@ -4105,7 +4199,17 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
apiKey := existing.APIKey
apiKeyKeyID := existing.ApiKeyKeyID
if req.APIKey != nil {
apiKey = strings.TrimSpace(*req.APIKey)
trimmedAPIKey := strings.TrimSpace(*req.APIKey)
if trimmedAPIKey != "" {
if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "API key too large.",
Detail: err.Error(),
})
return
}
}
apiKey = trimmedAPIKey
apiKeyKeyID = sql.NullString{}
}
baseURL := existing.BaseUrl
@@ -4120,13 +4224,57 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
}
}
updated, err := api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: displayName,
APIKey: apiKey,
BaseUrl: baseURL,
ApiKeyKeyID: apiKeyKeyID,
Enabled: enabled,
ID: existing.ID,
centralAPIKeyEnabled := existing.CentralApiKeyEnabled
if req.CentralAPIKeyEnabled != nil {
centralAPIKeyEnabled = *req.CentralAPIKeyEnabled
}
allowUserAPIKey := existing.AllowUserApiKey
if req.AllowUserAPIKey != nil {
allowUserAPIKey = *req.AllowUserAPIKey
}
allowCentralAPIKeyFallback := existing.AllowCentralApiKeyFallback
if req.AllowCentralAPIKeyFallback != nil {
allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback
}
if err := validateChatProviderCredentialPolicy(
centralAPIKeyEnabled,
allowUserAPIKey,
allowCentralAPIKeyFallback,
); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid credential policy.",
Detail: err.Error(),
})
return
}
if err := validateChatProviderCentralAPIKey(
centralAPIKeyEnabled,
api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{
ID: existing.ID,
Provider: existing.Provider,
APIKey: apiKey,
BaseUrl: baseURL,
CentralApiKeyEnabled: centralAPIKeyEnabled,
}, existing.ID),
); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
updated, err = api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: displayName,
APIKey: apiKey,
BaseUrl: baseURL,
ApiKeyKeyID: apiKeyKeyID,
Enabled: enabled,
CentralApiKeyEnabled: centralAPIKeyEnabled,
AllowUserApiKey: allowUserAPIKey,
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
ID: existing.ID,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -4197,6 +4345,169 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
)
//nolint:gocritic // Non-admin users need to read provider configs to manage their own chat credentials.
chatdCtx := dbauthz.AsChatd(ctx)
providers, err := api.Database.GetChatProviders(chatdCtx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list chat providers.",
Detail: err.Error(),
})
return
}
userKeys, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list user chat provider keys.",
Detail: err.Error(),
})
return
}
hasUserAPIKeyByProviderID := make(map[uuid.UUID]bool, len(userKeys))
for _, userKey := range userKeys {
hasUserAPIKeyByProviderID[userKey.ChatProviderID] = true
}
resp := make([]codersdk.UserChatProviderConfig, 0, len(providers))
for _, provider := range providers {
if !provider.Enabled || !provider.AllowUserApiKey {
continue
}
hasUserAPIKey := hasUserAPIKeyByProviderID[provider.ID]
hasCentralAPIKeyFallback := provider.Enabled &&
provider.AllowCentralApiKeyFallback &&
api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
resp = append(
resp,
convertUserChatProviderConfig(
provider,
hasUserAPIKey,
hasCentralAPIKeyFallback,
),
)
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
)
providerID, ok := parseChatProviderID(rw, r)
if !ok {
return
}
//nolint:gocritic // Non-admin users need to validate provider availability before storing their own key.
provider, err := api.Database.GetChatProviderByID(dbauthz.AsChatd(ctx), providerID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get chat provider.",
Detail: err.Error(),
})
return
}
if !provider.Enabled {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Provider is disabled.",
})
return
}
if !provider.AllowUserApiKey {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Provider does not allow user API keys.",
})
return
}
var req codersdk.CreateUserChatProviderKeyRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
trimmedAPIKey := strings.TrimSpace(req.APIKey)
if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "API key too large.",
Detail: err.Error(),
})
return
}
if trimmedAPIKey == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "API key is required.",
})
return
}
if _, err := api.Database.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: apiKey.UserID,
ChatProviderID: providerID,
APIKey: trimmedAPIKey,
ApiKeyKeyID: sql.NullString{},
}); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to save user chat provider key.",
Detail: err.Error(),
})
return
}
hasCentralAPIKeyFallback := provider.Enabled &&
provider.AllowCentralApiKeyFallback &&
api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
httpapi.Write(
ctx,
rw,
http.StatusOK,
convertUserChatProviderConfig(
provider,
true,
hasCentralAPIKeyFallback,
),
)
}
func (api *API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
)
providerID, ok := parseChatProviderID(rw, r)
if !ok {
return
}
if err := api.Database.DeleteUserChatProviderKey(ctx, database.DeleteUserChatProviderKeyParams{
UserID: apiKey.UserID,
ChatProviderID: providerID,
}); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to delete user chat provider key.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -4737,15 +5048,37 @@ func convertChatProviderConfig(
}
return codersdk.ChatProviderConfig{
ID: provider.ID,
Provider: provider.Provider,
DisplayName: displayName,
Enabled: provider.Enabled,
HasAPIKey: hasAPIKey,
BaseURL: strings.TrimSpace(provider.BaseUrl),
Source: source,
CreatedAt: provider.CreatedAt,
UpdatedAt: provider.UpdatedAt,
ID: provider.ID,
Provider: provider.Provider,
DisplayName: displayName,
Enabled: provider.Enabled,
HasAPIKey: hasAPIKey,
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserAPIKey: provider.AllowUserApiKey,
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
BaseURL: strings.TrimSpace(provider.BaseUrl),
Source: source,
CreatedAt: provider.CreatedAt,
UpdatedAt: provider.UpdatedAt,
}
}
func convertUserChatProviderConfig(
provider database.ChatProvider,
hasUserAPIKey bool,
hasCentralAPIKeyFallback bool,
) codersdk.UserChatProviderConfig {
displayName := strings.TrimSpace(provider.DisplayName)
if displayName == "" {
displayName = chatprovider.ProviderDisplayName(provider.Provider)
}
return codersdk.UserChatProviderConfig{
ProviderID: provider.ID,
Provider: provider.Provider,
DisplayName: displayName,
HasUserAPIKey: hasUserAPIKey,
HasCentralAPIKeyFallback: hasCentralAPIKeyFallback,
}
}
@@ -4904,26 +5237,80 @@ func chatProviderValidationDetail() string {
return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "."
}
func chatProviderAPIKeysFromDeploymentValues(
deploymentValues *codersdk.DeploymentValues,
) chatprovider.ProviderAPIKeys {
_ = deploymentValues
// For now, we'll just manage configs in the UI.
// We should probably not be reusing the AI bridge configs anyways.
return chatprovider.ProviderAPIKeys{
// OpenAI: deploymentValues.AI.BridgeConfig.OpenAI.Key.Value(),
// Anthropic: deploymentValues.AI.BridgeConfig.Anthropic.Key.Value(),
// BaseURLByProvider: map[string]string{
// "openai": deploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(),
// "anthropic": deploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(),
// },
const maxChatProviderAPIKeySize = 10240 // 10 KB
func validateChatProviderAPIKeySize(apiKey string) error {
if len(apiKey) > maxChatProviderAPIKeySize {
return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize)
}
return nil
}
//nolint:revive // This helper validates the explicit credential policy tuple.
func validateChatProviderCredentialPolicy(
centralEnabled, allowUserKey, allowFallback bool,
) error {
if !centralEnabled && !allowUserKey {
return xerrors.New(
"At least one credential source must be enabled: central API key or user API key.",
)
}
if allowFallback && !centralEnabled {
return xerrors.New(
"Central API key fallback requires central API key to be enabled.",
)
}
if allowFallback && !allowUserKey {
return xerrors.New(
"Central API key fallback requires user API key to be enabled.",
)
}
return nil
}
//nolint:revive // This helper validates central-key requirements.
func validateChatProviderCentralAPIKey(
centralEnabled bool,
hasCentralAPIKey bool,
) error {
if centralEnabled && !hasCentralAPIKey {
return xerrors.New(
"API key is required when central API key is enabled.",
)
}
return nil
}
// ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat
// provider API keys.
func ChatProviderAPIKeysFromDeploymentValues(
_ *codersdk.DeploymentValues,
) chatprovider.ProviderAPIKeys {
// AI bridge deployment config is intentionally not reused for chat
// provider credentials. Bridge keys serve the AI task subsystem and
// should not silently broaden into chat execution paths.
return chatprovider.ProviderAPIKeys{}
}
func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool {
return api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
}
func (api *API) hasEffectiveCentralProviderAPIKey(
ctx context.Context,
provider database.ChatProvider,
excludeProviderID uuid.UUID,
) bool {
if !provider.CentralApiKeyEnabled {
return false
}
if strings.TrimSpace(provider.APIKey) != "" {
return true
}
deploymentKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues)
if deploymentKeys.APIKey(provider.Provider) != "" {
return true
}
if api.chatDaemon == nil {
return false
}
@@ -4945,6 +5332,9 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
)
for _, configured := range enabledProviders {
if excludeProviderID != uuid.Nil && configured.ID == excludeProviderID {
continue
}
enabledConfiguredProviders = append(
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
Provider: configured.Provider,
@@ -4955,7 +5345,7 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
}
effectiveKeys := chatprovider.MergeProviderAPIKeys(
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
deploymentKeys,
enabledConfiguredProviders,
)
return effectiveKeys.APIKey(provider.Provider) != ""
File diff suppressed because it is too large Load Diff
+8 -7
View File
@@ -39,13 +39,14 @@ func TestChatParam(t *testing.T) {
t.Helper()
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
BaseUrl: "https://api.openai.com/v1",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
BaseUrl: "https://api.openai.com/v1",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
+65 -19
View File
@@ -1585,17 +1585,17 @@ func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) e
if err != nil {
return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
}
if isFreshManualTitleLock(lockedChat, now) {
// Only a fresh manual lock or a chat without a real worker should
// block title regeneration. Running chats with a real worker may
// regenerate their title concurrently, and last write wins.
hasRealWorker := lockedChat.Status == database.ChatStatusRunning &&
lockedChat.WorkerID.Valid &&
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
if lockedChat.Status == database.ChatStatusPending ||
(lockedChat.Status == database.ChatStatusRunning && !hasRealWorker) ||
isFreshManualTitleLock(lockedChat, now) {
return ErrManualTitleRegenerationInProgress
}
// Only write the lock marker when no real worker owns WorkerID.
// When a real worker is running, we skip the DB lock but still
// allow regeneration. The frontend prevents same-browser
// double-clicks, and concurrent regeneration from different
// replicas is harmless, last write wins.
hasRealWorker := lockedChat.WorkerID.Valid &&
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
if hasRealWorker {
return nil
}
@@ -1658,7 +1658,7 @@ func (p *Server) RegenerateChatTitle(
// keeping chat ownership authorization at the HTTP layer.
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
chatdCtx := dbauthz.AsChatd(ctx)
keys, err := p.resolveProviderAPIKeys(chatdCtx)
keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID)
if err != nil {
return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err)
}
@@ -4808,7 +4808,7 @@ func (p *Server) resolveChatModel(
})
g.Go(func() error {
var err error
keys, err = p.resolveProviderAPIKeys(ctx)
keys, err = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID)
if err != nil {
return xerrors.Errorf("resolve provider API keys: %w", err)
}
@@ -4830,8 +4830,9 @@ func (p *Server) resolveChatModel(
return model, dbConfig, keys, nil
}
func (p *Server) resolveProviderAPIKeys(
func (p *Server) resolveUserProviderAPIKeys(
ctx context.Context,
ownerID uuid.UUID,
) (chatprovider.ProviderAPIKeys, error) {
providers, err := p.configCache.EnabledProviders(ctx)
if err != nil {
@@ -4840,17 +4841,62 @@ func (p *Server) resolveProviderAPIKeys(
err,
)
}
dbProviders := make(
configuredProviders := make(
[]chatprovider.ConfiguredProvider, 0, len(providers),
)
for _, provider := range providers {
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
})
configuredProviders = append(
configuredProviders, chatprovider.ConfiguredProvider{
ProviderID: provider.ID,
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserAPIKey: provider.AllowUserApiKey,
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
},
)
}
return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil
allowAnyUserAPIKey := false
for _, provider := range configuredProviders {
if provider.AllowUserAPIKey {
allowAnyUserAPIKey = true
break
}
}
userKeys := []chatprovider.UserProviderKey{}
if allowAnyUserAPIKey {
userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID)
if err != nil {
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
"get user chat provider keys: %w",
err,
)
}
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
for _, userKey := range userKeyRows {
userKeys = append(userKeys, chatprovider.UserProviderKey{
ChatProviderID: userKey.ChatProviderID,
APIKey: userKey.APIKey,
})
}
}
keys, _ := chatprovider.ResolveUserProviderKeys(
p.providerAPIKeys,
configuredProviders,
userKeys,
)
enabledProviders := make(map[string]struct{}, len(configuredProviders))
for _, provider := range configuredProviders {
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
if normalizedProvider == "" {
continue
}
enabledProviders[normalizedProvider] = struct{}{}
}
chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders)
return keys, nil
}
// resolveModelConfig looks up the chat's model config by its
+92 -7
View File
@@ -23,6 +23,7 @@ import (
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
@@ -99,9 +100,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
APIKey: "test-key",
BaseUrl: serverURL,
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
@@ -261,9 +263,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
APIKey: "test-key",
BaseUrl: serverURL,
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
@@ -378,6 +381,87 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
}
}
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
server := &Server{
db: db,
configCache: newChatConfigCache(
context.Background(),
db,
quartz.NewReal(),
),
providerAPIKeys: chatprovider.ProviderAPIKeys{
OpenAI: "openai-deployment-key",
Anthropic: "anthropic-deployment-key",
ByProvider: map[string]string{
"openai": "openai-deployment-key",
"anthropic": "anthropic-deployment-key",
},
BaseURLByProvider: map[string]string{
"openai": "https://openai.example.com",
"anthropic": "https://anthropic.example.com",
},
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "anthropic",
CentralApiKeyEnabled: true,
AllowCentralApiKeyFallback: true,
}}, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
require.NoError(t, err)
require.Empty(t, keys.OpenAI)
require.Empty(t, keys.APIKey("openai"))
require.Empty(t, keys.BaseURL("openai"))
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
}
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
server := &Server{
db: db,
configCache: newChatConfigCache(
context.Background(),
db,
quartz.NewReal(),
),
providerAPIKeys: chatprovider.ProviderAPIKeys{
OpenAI: "openai-deployment-key",
ByProvider: map[string]string{
"openai": "openai-deployment-key",
},
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
}}, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
require.NoError(t, err)
require.Equal(t, "openai-deployment-key", keys.OpenAI)
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
}
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
t.Parallel()
@@ -523,7 +607,8 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
workspacesdk.LSResponse{},
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
).AnyTimes()
conn.EXPECT().ReadFile(gomock.Any(),
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/project/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1)).Return(
+248 -18
View File
@@ -2893,12 +2893,13 @@ func seedChatDependenciesWithProvider(
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: provider,
APIKey: "test-key",
BaseUrl: baseURL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
Provider: provider,
DisplayName: provider,
APIKey: "test-key",
BaseUrl: baseURL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
@@ -2917,6 +2918,102 @@ func seedChatDependenciesWithProvider(
return user, model
}
func seedChatDependenciesWithProviderPolicy(
ctx context.Context,
t *testing.T,
db database.Store,
provider string,
baseURL string,
apiKey string,
centralAPIKeyEnabled bool,
allowUserAPIKey bool,
allowCentralAPIKeyFallback bool,
) (database.User, database.ChatProvider, database.ChatModelConfig) {
t.Helper()
user := dbgen.User(t, db, database.User{})
providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: provider,
APIKey: apiKey,
BaseUrl: baseURL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: centralAPIKeyEnabled,
AllowUserApiKey: allowUserAPIKey,
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: provider,
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
return user, providerConfig, model
}
func waitForTerminalChatStatusEvent(
ctx context.Context,
t *testing.T,
events <-chan codersdk.ChatStreamEvent,
) codersdk.ChatStatus {
t.Helper()
var terminalStatus codersdk.ChatStatus
testutil.Eventually(ctx, t, func(context.Context) bool {
for {
select {
case event, ok := <-events:
if !ok {
return false
}
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
continue
}
if event.Status.Status == codersdk.ChatStatusWaiting || event.Status.Status == codersdk.ChatStatusError {
terminalStatus = event.Status.Status
return true
}
default:
return false
}
}
}, testutil.IntervalFast)
return terminalStatus
}
func waitForTerminalChat(
ctx context.Context,
t *testing.T,
db database.Store,
chatID uuid.UUID,
) database.Chat {
t.Helper()
var chatResult database.Chat
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
got, err := db.GetChatByID(ctx, chatID)
if err != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.IntervalFast)
return chatResult
}
// seedWorkspaceWithAgent creates a full workspace chain with a connected
// agent. This is the common setup needed by tests that exercise tool
// execution against a workspace.
@@ -2973,12 +3070,15 @@ func setOpenAIProviderBaseURL(
require.NoError(t, err)
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
})
require.NoError(t, err)
}
@@ -3552,12 +3652,13 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
// Add an Anthropic provider pointing to our mock server.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-anthropic-key",
BaseUrl: anthropicSrv.URL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
Provider: "anthropic",
DisplayName: "Anthropic",
APIKey: "test-anthropic-key",
BaseUrl: anthropicSrv.URL,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -3841,6 +3942,135 @@ func TestInterruptChatPersistsPartialResponse(t *testing.T) {
"partial assistant response should contain the streamed text")
}
func TestProcessChat_UserProviderKey_Success(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
const userAPIKey = "user-test-key"
var authHeadersMu sync.Mutex
authHeaders := make([]string, 0, 1)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
authHeadersMu.Lock()
authHeaders = append(authHeaders, req.Header.Get("Authorization"))
authHeadersMu.Unlock()
if !req.Stream {
return chattest.OpenAINonStreamingResponse("user provider key success")
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("hello from the saved user key")...,
)
})
user, provider, model := seedChatDependenciesWithProviderPolicy(
ctx,
t,
db,
"openai-compat",
openAIURL,
"",
false,
true,
false,
)
_, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: user.ID,
ChatProviderID: provider.ID,
APIKey: userAPIKey,
})
require.NoError(t, err)
creator := newTestServer(t, db, ps, uuid.New())
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "user-provider-key-success",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("say hello"),
},
})
require.NoError(t, err)
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
_ = newActiveTestServer(t, db, ps)
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
require.False(t, chatResult.LastError.Valid)
authHeadersMu.Lock()
recordedAuthHeaders := append([]string(nil), authHeaders...)
authHeadersMu.Unlock()
require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey)
}
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
var llmCalls atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
llmCalls.Add(1)
if !req.Stream {
return chattest.OpenAINonStreamingResponse("unexpected non-streaming request")
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("unexpected streaming request")...,
)
})
user, _, model := seedChatDependenciesWithProviderPolicy(
ctx,
t,
db,
"openai-compat",
openAIURL,
"",
false,
true,
false,
)
creator := newTestServer(t, db, ps, uuid.New())
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "user-provider-key-missing",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("say hello"),
},
})
require.NoError(t, err)
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
_ = newActiveTestServer(t, db, ps)
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
require.Equal(t, codersdk.ChatStatusError, terminalStatus)
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
require.Equal(t, database.ChatStatusError, chatResult.Status)
require.True(t, chatResult.LastError.Valid, "LastError should be set")
require.NotEmpty(t, chatResult.LastError.String)
require.NotContains(t, chatResult.LastError.String, "panicked")
require.NotEqual(t, database.ChatStatusRunning, chatResult.Status)
require.Zero(t, llmCalls.Load(), "missing user key should fail before any LLM request")
}
func TestProcessChatPanicRecovery(t *testing.T) {
t.Parallel()
+12 -10
View File
@@ -1459,11 +1459,12 @@ func TestNulEscapeRoundTrip(t *testing.T) {
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "openai",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
Provider: "openai",
DisplayName: "openai",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
@@ -1943,11 +1944,12 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "anthropic",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
Provider: "anthropic",
DisplayName: "anthropic",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
+196 -20
View File
@@ -81,11 +81,28 @@ type ProviderAPIKeys struct {
BaseURLByProvider map[string]string
}
// UserProviderKey is a user-supplied API key for a specific provider.
type UserProviderKey struct {
ChatProviderID uuid.UUID
APIKey string
}
// ProviderAvailability describes whether a provider has a usable
// API key and, if not, why.
type ProviderAvailability struct {
Available bool
UnavailableReason codersdk.ChatModelProviderUnavailableReason
}
// ConfiguredProvider is an enabled provider loaded from database config.
type ConfiguredProvider struct {
Provider string
APIKey string
BaseURL string
ProviderID uuid.UUID
Provider string
APIKey string
BaseURL string
CentralAPIKeyEnabled bool
AllowUserAPIKey bool
AllowCentralAPIKeyFallback bool
}
// ConfiguredModel is an enabled model loaded from database config.
@@ -189,21 +206,146 @@ func MergeProviderAPIKeys(fallback ProviderAPIKeys, providers []ConfiguredProvid
return merged
}
type ModelCatalog struct {
keys ProviderAPIKeys
// ResolveUserProviderKeys computes effective API keys and per-provider
// availability for a given user. It considers the provider's credential
// policy flags alongside central (DB/deployment) keys and the user's
// personal keys.
func ResolveUserProviderKeys(
fallback ProviderAPIKeys,
providers []ConfiguredProvider,
userKeys []UserProviderKey,
) (ProviderAPIKeys, map[string]ProviderAvailability) {
merged := ProviderAPIKeys{
OpenAI: strings.TrimSpace(fallback.OpenAI),
Anthropic: strings.TrimSpace(fallback.Anthropic),
ByProvider: map[string]string{},
BaseURLByProvider: map[string]string{},
}
for provider, apiKey := range fallback.ByProvider {
normalizedProvider := NormalizeProvider(provider)
if normalizedProvider == "" {
continue
}
if key := strings.TrimSpace(apiKey); key != "" {
merged.ByProvider[normalizedProvider] = key
}
}
for provider, baseURL := range fallback.BaseURLByProvider {
normalizedProvider := NormalizeProvider(provider)
if normalizedProvider == "" {
continue
}
if url := strings.TrimSpace(baseURL); url != "" {
merged.BaseURLByProvider[normalizedProvider] = url
}
}
if merged.OpenAI != "" {
merged.ByProvider[fantasyopenai.Name] = merged.OpenAI
}
if merged.Anthropic != "" {
merged.ByProvider[fantasyanthropic.Name] = merged.Anthropic
}
userKeyByProviderID := make(map[uuid.UUID]string, len(userKeys))
for _, userKey := range userKeys {
if userKey.ChatProviderID == uuid.Nil {
continue
}
if key := strings.TrimSpace(userKey.APIKey); key != "" {
userKeyByProviderID[userKey.ChatProviderID] = key
}
}
availabilityByProvider := make(map[string]ProviderAvailability, len(providers))
for _, provider := range providers {
normalizedProvider := NormalizeProvider(provider.Provider)
if normalizedProvider == "" {
continue
}
if url := strings.TrimSpace(provider.BaseURL); url != "" {
merged.BaseURLByProvider[normalizedProvider] = url
}
var userKey string
if provider.ProviderID != uuid.Nil {
userKey = userKeyByProviderID[provider.ProviderID]
}
var centralKey string
if provider.CentralAPIKeyEnabled {
if key := strings.TrimSpace(provider.APIKey); key != "" {
centralKey = key
} else {
centralKey = fallback.APIKey(normalizedProvider)
}
}
resolved := ProviderAvailability{}
chosenKey := ""
switch {
case provider.AllowUserAPIKey && userKey != "":
chosenKey = userKey
resolved.Available = true
case centralKey != "":
if !provider.AllowUserAPIKey || provider.AllowCentralAPIKeyFallback {
chosenKey = centralKey
resolved.Available = true
} else {
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
}
case provider.AllowUserAPIKey && provider.AllowCentralAPIKeyFallback && provider.CentralAPIKeyEnabled:
// When users can add their own key, a missing central fallback key is
// still something the user can remedy.
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
case provider.AllowUserAPIKey:
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
default:
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
}
setResolvedProviderAPIKey(&merged, normalizedProvider, chosenKey)
availabilityByProvider[normalizedProvider] = resolved
}
return merged, availabilityByProvider
}
func NewModelCatalog(keys ProviderAPIKeys) *ModelCatalog {
return &ModelCatalog{
keys: keys,
func setResolvedProviderAPIKey(keys *ProviderAPIKeys, provider string, apiKey string) {
normalizedProvider := NormalizeProvider(provider)
if normalizedProvider == "" {
return
}
if keys.ByProvider == nil {
keys.ByProvider = map[string]string{}
}
delete(keys.ByProvider, normalizedProvider)
trimmedKey := strings.TrimSpace(apiKey)
switch normalizedProvider {
case fantasyopenai.Name:
keys.OpenAI = trimmedKey
case fantasyanthropic.Name:
keys.Anthropic = trimmedKey
}
if trimmedKey != "" {
keys.ByProvider[normalizedProvider] = trimmedKey
}
}
type ModelCatalog struct{}
func NewModelCatalog() *ModelCatalog {
return &ModelCatalog{}
}
// ListConfiguredModels returns a model catalog from enabled DB-backed model
// configs. The second return value reports whether DB-backed models were used.
func (c *ModelCatalog) ListConfiguredModels(
func (*ModelCatalog) ListConfiguredModels(
configuredProviders []ConfiguredProvider,
configuredModels []ConfiguredModel,
availabilityByProvider map[string]ProviderAvailability,
enabledProviders map[string]struct{},
) (codersdk.ChatModelsResponse, bool) {
if len(configuredModels) == 0 {
return codersdk.ChatModelsResponse{}, false
@@ -247,11 +389,14 @@ func (c *ModelCatalog) ListConfiguredModels(
return codersdk.ChatModelsResponse{}, false
}
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
response := codersdk.ChatModelsResponse{
Providers: make([]codersdk.ChatModelProvider, 0, len(providers)),
}
for _, provider := range providers {
if _, ok := enabledProviders[provider]; !ok {
continue
}
models := modelsByProvider[provider]
sortChatModels(models)
@@ -259,11 +404,14 @@ func (c *ModelCatalog) ListConfiguredModels(
Provider: provider,
Models: models,
}
if keys.APIKey(provider) == "" {
if avail, ok := availabilityByProvider[provider]; ok {
result.Available = avail.Available
if !avail.Available {
result.UnavailableReason = avail.UnavailableReason
}
} else {
result.Available = false
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
} else {
result.Available = true
}
response.Providers = append(response.Providers, result)
@@ -273,25 +421,32 @@ func (c *ModelCatalog) ListConfiguredModels(
}
// ListConfiguredProviderAvailability returns provider availability derived from
// deployment/env keys merged with enabled DB provider keys.
func (c *ModelCatalog) ListConfiguredProviderAvailability(
configuredProviders []ConfiguredProvider,
// the policy-aware availability map for enabled providers.
func (*ModelCatalog) ListConfiguredProviderAvailability(
availabilityByProvider map[string]ProviderAvailability,
enabledProviders map[string]struct{},
) codersdk.ChatModelsResponse {
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
response := codersdk.ChatModelsResponse{
Providers: make([]codersdk.ChatModelProvider, 0, len(supportedProviderNames)),
}
for _, provider := range supportedProviderNames {
if _, ok := enabledProviders[provider]; !ok {
continue
}
result := codersdk.ChatModelProvider{
Provider: provider,
Models: []codersdk.ChatModel{},
}
if keys.APIKey(provider) == "" {
if avail, ok := availabilityByProvider[provider]; ok {
result.Available = avail.Available
if !avail.Available {
result.UnavailableReason = avail.UnavailableReason
}
} else {
result.Available = false
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
} else {
result.Available = true
}
response.Providers = append(response.Providers, result)
@@ -300,6 +455,27 @@ func (c *ModelCatalog) ListConfiguredProviderAvailability(
return response
}
// PruneDisabledProviderKeys removes entries from keys that do not
// belong to an enabled provider. It clears ByProvider and
// BaseURLByProvider entries for disabled providers and zeroes the
// legacy OpenAI and Anthropic fields when those providers are not
// enabled.
func PruneDisabledProviderKeys(keys *ProviderAPIKeys, enabledProviders map[string]struct{}) {
for provider := range keys.ByProvider {
if _, ok := enabledProviders[provider]; ok {
continue
}
delete(keys.ByProvider, provider)
delete(keys.BaseURLByProvider, provider)
}
if _, ok := enabledProviders[NormalizeProvider("openai")]; !ok {
keys.OpenAI = ""
}
if _, ok := enabledProviders[NormalizeProvider("anthropic")]; !ok {
keys.Anthropic = ""
}
}
func newChatModel(provider, modelID, displayName string) codersdk.ChatModel {
name := strings.TrimSpace(displayName)
if name == "" {
@@ -21,6 +21,166 @@ import (
"github.com/coder/coder/v2/testutil"
)
func TestResolveUserProviderKeys(t *testing.T) {
t.Parallel()
configuredProvider := func(id uuid.UUID, provider string, centralEnabled bool, centralKey string, allowUser bool, allowCentralFallback bool) chatprovider.ConfiguredProvider {
return chatprovider.ConfiguredProvider{
ProviderID: id,
Provider: provider,
APIKey: centralKey,
CentralAPIKeyEnabled: centralEnabled,
AllowUserAPIKey: allowUser,
AllowCentralAPIKeyFallback: allowCentralFallback,
}
}
userProviderKey := func(id uuid.UUID, apiKey string) chatprovider.UserProviderKey {
return chatprovider.UserProviderKey{
ChatProviderID: id,
APIKey: apiKey,
}
}
openAIProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000001")
anthropicProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
tests := []struct {
name string
fallback chatprovider.ProviderAPIKeys
providers []chatprovider.ConfiguredProvider
userKeys []chatprovider.UserProviderKey
wantAvailability map[string]chatprovider.ProviderAvailability
wantKeys map[string]string
}{
{
name: "CentralOnlyKeyPresent",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false)},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: true},
},
wantKeys: map[string]string{
fantasyopenai.Name: "sk-central",
},
},
{
name: "CentralOnlyKeyMissing",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", false, false)},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey},
},
wantKeys: map[string]string{
fantasyopenai.Name: "",
},
},
{
name: "UserOnlyUserHasKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)},
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: true},
},
wantKeys: map[string]string{
fantasyopenai.Name: "sk-user",
},
},
{
name: "UserOnlyUserHasNoKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
},
wantKeys: map[string]string{
fantasyopenai.Name: "",
},
},
{
name: "BothEnabledFallbackOffUserHasKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)},
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: true},
},
wantKeys: map[string]string{
fantasyopenai.Name: "sk-user",
},
},
{
name: "BothEnabledFallbackOffUserHasNoKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
},
wantKeys: map[string]string{
fantasyopenai.Name: "",
},
},
{
name: "BothEnabledFallbackOnUserHasKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)},
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: true},
},
wantKeys: map[string]string{
fantasyopenai.Name: "sk-user",
},
},
{
name: "BothEnabledFallbackOnUserHasNoKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: true},
},
wantKeys: map[string]string{
fantasyopenai.Name: "sk-central",
},
},
{
name: "BothEnabledFallbackOnCentralKeyEmptyUserHasNoKey",
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", true, true)},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
},
wantKeys: map[string]string{
fantasyopenai.Name: "",
},
},
{
name: "MultipleProvidersDifferentPolicies",
providers: []chatprovider.ConfiguredProvider{
configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false),
configuredProvider(anthropicProviderID, fantasyanthropic.Name, false, "", true, false),
},
wantAvailability: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {Available: true},
fantasyanthropic.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
},
wantKeys: map[string]string{
fantasyopenai.Name: "sk-central",
fantasyanthropic.Name: "",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
keys, availability := chatprovider.ResolveUserProviderKeys(tt.fallback, tt.providers, tt.userKeys)
require.Len(t, availability, len(tt.wantAvailability))
for provider, wantAvailability := range tt.wantAvailability {
gotAvailability, ok := availability[provider]
require.True(t, ok, "expected availability for provider %q", provider)
require.Equal(t, wantAvailability, gotAvailability)
require.Equal(t, tt.wantKeys[provider], keys.APIKey(provider))
}
})
}
}
func TestReasoningEffortFromChat(t *testing.T) {
t.Parallel()
@@ -91,6 +251,413 @@ func TestReasoningEffortFromChat(t *testing.T) {
}
}
func TestResolveUserProviderKeys_UnavailableReason(t *testing.T) {
t.Parallel()
tests := []struct {
name string
provider chatprovider.ConfiguredProvider
wantReason codersdk.ChatModelProviderUnavailableReason
}{
{
name: "FallbackConfiguredWithoutCentralKeyReturnsUserAPIKeyRequired",
provider: chatprovider.ConfiguredProvider{
Provider: "anthropic",
CentralAPIKeyEnabled: true,
AllowUserAPIKey: true,
AllowCentralAPIKeyFallback: true,
},
wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
},
{
name: "UserKeyRequiredWithoutFallback",
provider: chatprovider.ConfiguredProvider{
Provider: "anthropic",
CentralAPIKeyEnabled: true,
AllowUserAPIKey: true,
},
wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
keys, availability := chatprovider.ResolveUserProviderKeys(
chatprovider.ProviderAPIKeys{},
[]chatprovider.ConfiguredProvider{tt.provider},
nil,
)
require.Empty(t, keys.APIKey(tt.provider.Provider))
resolved, ok := availability[tt.provider.Provider]
require.True(t, ok)
require.False(t, resolved.Available)
require.Equal(t, tt.wantReason, resolved.UnavailableReason)
})
}
}
func TestListConfiguredModels_PolicyAwareAvailability(t *testing.T) {
t.Parallel()
configuredProvider := func(provider string, apiKey string) chatprovider.ConfiguredProvider {
return chatprovider.ConfiguredProvider{
ProviderID: uuid.New(),
Provider: provider,
APIKey: apiKey,
}
}
enabledProviders := func(providers ...string) map[string]struct{} {
result := make(map[string]struct{}, len(providers))
for _, provider := range providers {
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
}
return result
}
catalog := chatprovider.NewModelCatalog()
tests := []struct {
name string
configuredProviders []chatprovider.ConfiguredProvider
configuredModels []chatprovider.ConfiguredModel
availabilityByProvider map[string]chatprovider.ProviderAvailability
enabledProviders map[string]struct{}
want codersdk.ChatModelsResponse
}{
{
name: "PolicyUnavailableOverridesConfiguredKey",
configuredProviders: []chatprovider.ConfiguredProvider{
configuredProvider(fantasyopenai.Name, "sk-central"),
},
configuredModels: []chatprovider.ConfiguredModel{{
Provider: fantasyopenai.Name,
Model: "gpt-4",
}},
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
fantasyopenai.Name: {
Available: false,
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
},
},
enabledProviders: enabledProviders(fantasyopenai.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
Provider: fantasyopenai.Name,
Available: false,
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
Models: []codersdk.ChatModel{{
ID: fantasyopenai.Name + ":gpt-4",
Provider: fantasyopenai.Name,
Model: "gpt-4",
DisplayName: "gpt-4",
}},
}}},
},
{
name: "PolicyAvailableMarksProviderAvailable",
configuredProviders: []chatprovider.ConfiguredProvider{
configuredProvider(fantasyanthropic.Name, "sk-central"),
},
configuredModels: []chatprovider.ConfiguredModel{{
Provider: fantasyanthropic.Name,
Model: "claude-3-5-sonnet",
}},
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
fantasyanthropic.Name: {Available: true},
},
enabledProviders: enabledProviders(fantasyanthropic.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
Provider: fantasyanthropic.Name,
Available: true,
Models: []codersdk.ChatModel{{
ID: fantasyanthropic.Name + ":claude-3-5-sonnet",
Provider: fantasyanthropic.Name,
Model: "claude-3-5-sonnet",
DisplayName: "claude-3-5-sonnet",
}},
}}},
},
{
name: "DisabledProviderOmitted",
configuredProviders: []chatprovider.ConfiguredProvider{
configuredProvider(fantasyanthropic.Name, "sk-anthropic"),
configuredProvider(fantasyopenai.Name, "sk-openai"),
},
configuredModels: []chatprovider.ConfiguredModel{
{Provider: fantasyanthropic.Name, Model: "claude-3-5-sonnet"},
{Provider: fantasyopenai.Name, Model: "gpt-4"},
},
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
fantasyanthropic.Name: {Available: true},
fantasyopenai.Name: {Available: true},
},
enabledProviders: enabledProviders(fantasyopenai.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
Provider: fantasyopenai.Name,
Available: true,
Models: []codersdk.ChatModel{{
ID: fantasyopenai.Name + ":gpt-4",
Provider: fantasyopenai.Name,
Model: "gpt-4",
DisplayName: "gpt-4",
}},
}}},
},
{
name: "MissingAvailabilityDefaultsToMissingAPIKey",
configuredProviders: []chatprovider.ConfiguredProvider{
configuredProvider(fantasyopenai.Name, "sk-central"),
},
configuredModels: []chatprovider.ConfiguredModel{{
Provider: fantasyopenai.Name,
Model: "gpt-4o",
}},
enabledProviders: enabledProviders(fantasyopenai.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
Provider: fantasyopenai.Name,
Available: false,
UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey,
Models: []codersdk.ChatModel{{
ID: fantasyopenai.Name + ":gpt-4o",
Provider: fantasyopenai.Name,
Model: "gpt-4o",
DisplayName: "gpt-4o",
}},
}}},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := catalog.ListConfiguredModels(
tt.configuredProviders,
tt.configuredModels,
tt.availabilityByProvider,
tt.enabledProviders,
)
require.True(t, ok)
require.Equal(t, tt.want, got)
})
}
}
func TestListConfiguredProviderAvailability_PolicyAwareFiltering(t *testing.T) {
t.Parallel()
enabledProviders := func(providers ...string) map[string]struct{} {
result := make(map[string]struct{}, len(providers))
for _, provider := range providers {
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
}
return result
}
catalog := chatprovider.NewModelCatalog()
tests := []struct {
name string
availabilityByProvider map[string]chatprovider.ProviderAvailability
enabledProviders map[string]struct{}
want codersdk.ChatModelsResponse
}{
{
name: "EnabledProvidersUsePolicyAvailability",
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
fantasyanthropic.Name: {
Available: false,
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
},
fantasyopenai.Name: {Available: true},
},
enabledProviders: enabledProviders(fantasyanthropic.Name, fantasyopenai.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{
{
Provider: fantasyanthropic.Name,
Available: false,
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
Models: []codersdk.ChatModel{},
},
{
Provider: fantasyopenai.Name,
Available: true,
Models: []codersdk.ChatModel{},
},
}},
},
{
name: "DisabledSupportedProviderOmitted",
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
fantasyanthropic.Name: {Available: true},
fantasyopenai.Name: {Available: true},
},
enabledProviders: enabledProviders(fantasyopenai.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
Provider: fantasyopenai.Name,
Available: true,
Models: []codersdk.ChatModel{},
}}},
},
{
name: "MissingAvailabilityDefaultsToMissingAPIKey",
enabledProviders: enabledProviders(fantasyopenai.Name),
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
Provider: fantasyopenai.Name,
Available: false,
UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey,
Models: []codersdk.ChatModel{},
}}},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := catalog.ListConfiguredProviderAvailability(
tt.availabilityByProvider,
tt.enabledProviders,
)
require.Equal(t, tt.want, got)
})
}
}
func TestPruneDisabledProviderKeys(t *testing.T) {
t.Parallel()
enabledProviders := func(providers ...string) map[string]struct{} {
result := make(map[string]struct{}, len(providers))
for _, provider := range providers {
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
}
return result
}
tests := []struct {
name string
keys chatprovider.ProviderAPIKeys
enabledProviders map[string]struct{}
want chatprovider.ProviderAPIKeys
}{
{
name: "DisabledProviderEntriesRemoved",
keys: chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{
fantasyanthropic.Name: "sk-anthropic",
fantasyopenai.Name: "sk-openai",
},
BaseURLByProvider: map[string]string{
fantasyanthropic.Name: "https://anthropic.example.com",
fantasyopenai.Name: "https://openai.example.com",
},
},
enabledProviders: enabledProviders(fantasyopenai.Name),
want: chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{
fantasyopenai.Name: "sk-openai",
},
BaseURLByProvider: map[string]string{
fantasyopenai.Name: "https://openai.example.com",
},
},
},
{
name: "OpenAIDisabledClearsLegacyField",
keys: chatprovider.ProviderAPIKeys{
OpenAI: "sk-openai",
Anthropic: "sk-anthropic",
ByProvider: map[string]string{
fantasyopenai.Name: "sk-openai",
fantasyanthropic.Name: "sk-anthropic",
},
BaseURLByProvider: map[string]string{
fantasyopenai.Name: "https://openai.example.com",
fantasyanthropic.Name: "https://anthropic.example.com",
},
},
enabledProviders: enabledProviders(fantasyanthropic.Name),
want: chatprovider.ProviderAPIKeys{
Anthropic: "sk-anthropic",
ByProvider: map[string]string{
fantasyanthropic.Name: "sk-anthropic",
},
BaseURLByProvider: map[string]string{
fantasyanthropic.Name: "https://anthropic.example.com",
},
},
},
{
name: "AnthropicDisabledClearsLegacyField",
keys: chatprovider.ProviderAPIKeys{
OpenAI: "sk-openai",
Anthropic: "sk-anthropic",
ByProvider: map[string]string{
fantasyopenai.Name: "sk-openai",
fantasyanthropic.Name: "sk-anthropic",
},
BaseURLByProvider: map[string]string{
fantasyopenai.Name: "https://openai.example.com",
fantasyanthropic.Name: "https://anthropic.example.com",
},
},
enabledProviders: enabledProviders(fantasyopenai.Name),
want: chatprovider.ProviderAPIKeys{
OpenAI: "sk-openai",
ByProvider: map[string]string{
fantasyopenai.Name: "sk-openai",
},
BaseURLByProvider: map[string]string{
fantasyopenai.Name: "https://openai.example.com",
},
},
},
{
name: "AllEnabledLeavesKeysUnchanged",
keys: chatprovider.ProviderAPIKeys{
OpenAI: "sk-openai",
Anthropic: "sk-anthropic",
ByProvider: map[string]string{
fantasyopenai.Name: "sk-openai",
fantasyanthropic.Name: "sk-anthropic",
},
BaseURLByProvider: map[string]string{
fantasyopenai.Name: "https://openai.example.com",
fantasyanthropic.Name: "https://anthropic.example.com",
},
},
enabledProviders: enabledProviders(fantasyopenai.Name, fantasyanthropic.Name),
want: chatprovider.ProviderAPIKeys{
OpenAI: "sk-openai",
Anthropic: "sk-anthropic",
ByProvider: map[string]string{
fantasyopenai.Name: "sk-openai",
fantasyanthropic.Name: "sk-anthropic",
},
BaseURLByProvider: map[string]string{
fantasyopenai.Name: "https://openai.example.com",
fantasyanthropic.Name: "https://anthropic.example.com",
},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
keys := tt.keys
chatprovider.PruneDisabledProviderKeys(&keys, tt.enabledProviders)
require.Equal(t, tt.want, keys)
})
}
}
func TestCoderHeaders(t *testing.T) {
t.Parallel()
@@ -440,13 +440,14 @@ func seedModelConfig(
t.Helper()
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: "",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: "",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
+8 -7
View File
@@ -122,13 +122,14 @@ func seedInternalChatDeps(
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: "",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: "",
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
+4 -3
View File
@@ -959,9 +959,10 @@ func TestWorker(t *testing.T) {
// 3. Set up FK chain: chat_providers -> chat_model_configs -> chats.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
+86 -18
View File
@@ -405,6 +405,8 @@ type ChatModelProviderUnavailableReason string
const (
ChatModelProviderUnavailableMissingAPIKey ChatModelProviderUnavailableReason = "missing_api_key"
ChatModelProviderUnavailableFetchFailed ChatModelProviderUnavailableReason = "fetch_failed"
// #nosec G101
ChatModelProviderUnavailableReasonUserAPIKeyRequired ChatModelProviderUnavailableReason = "user_api_key_required"
)
// ChatModel represents a model in the chat model catalog.
@@ -532,32 +534,57 @@ const (
// ChatProviderConfig is an admin-managed provider configuration.
type ChatProviderConfig struct {
ID uuid.UUID `json:"id" format:"uuid"`
Provider string `json:"provider"`
DisplayName string `json:"display_name"`
Enabled bool `json:"enabled"`
HasAPIKey bool `json:"has_api_key"`
BaseURL string `json:"base_url,omitempty"`
Source ChatProviderConfigSource `json:"source"`
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
ID uuid.UUID `json:"id" format:"uuid"`
Provider string `json:"provider"`
DisplayName string `json:"display_name"`
Enabled bool `json:"enabled"`
HasAPIKey bool `json:"has_api_key"`
CentralAPIKeyEnabled bool `json:"central_api_key_enabled"`
AllowUserAPIKey bool `json:"allow_user_api_key"`
AllowCentralAPIKeyFallback bool `json:"allow_central_api_key_fallback"`
BaseURL string `json:"base_url,omitempty"`
Source ChatProviderConfigSource `json:"source"`
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
}
// CreateChatProviderConfigRequest creates a chat provider config.
type CreateChatProviderConfigRequest struct {
Provider string `json:"provider"`
DisplayName string `json:"display_name,omitempty"`
APIKey string `json:"api_key,omitempty"`
BaseURL string `json:"base_url,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
Provider string `json:"provider"`
DisplayName string `json:"display_name,omitempty"`
APIKey string `json:"api_key,omitempty"`
BaseURL string `json:"base_url,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
CentralAPIKeyEnabled *bool `json:"central_api_key_enabled,omitempty"`
AllowUserAPIKey *bool `json:"allow_user_api_key,omitempty"`
AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"`
}
// UpdateChatProviderConfigRequest updates a chat provider config.
type UpdateChatProviderConfigRequest struct {
DisplayName string `json:"display_name,omitempty"`
APIKey *string `json:"api_key,omitempty"`
BaseURL *string `json:"base_url,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
DisplayName string `json:"display_name,omitempty"`
APIKey *string `json:"api_key,omitempty"`
BaseURL *string `json:"base_url,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
CentralAPIKeyEnabled *bool `json:"central_api_key_enabled,omitempty"`
AllowUserAPIKey *bool `json:"allow_user_api_key,omitempty"`
AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"`
}
// UserChatProviderConfig is a summary of a provider that allows
// user-supplied keys, as seen from the current user's perspective.
type UserChatProviderConfig struct {
ProviderID uuid.UUID `json:"provider_id" format:"uuid"`
Provider string `json:"provider"`
DisplayName string `json:"display_name"`
HasUserAPIKey bool `json:"has_user_api_key"`
HasCentralAPIKeyFallback bool `json:"has_central_api_key_fallback"`
}
// CreateUserChatProviderKeyRequest creates or replaces a user's API key
// for a provider.
type CreateUserChatProviderKeyRequest struct {
APIKey string `json:"api_key"`
}
// ChatModelConfig is an admin-managed model configuration.
@@ -1332,6 +1359,47 @@ func (c *ExperimentalClient) DeleteChatProvider(ctx context.Context, providerID
return nil
}
// ListUserChatProviderConfigs returns user-scoped chat provider configs.
func (c *ExperimentalClient) ListUserChatProviderConfigs(ctx context.Context) ([]UserChatProviderConfig, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-provider-configs", nil)
if err != nil {
return nil, xerrors.Errorf("list user chat provider configs: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var configs []UserChatProviderConfig
return configs, json.NewDecoder(res.Body).Decode(&configs)
}
// UpsertUserChatProviderKey creates or replaces a user API key for a provider.
func (c *ExperimentalClient) UpsertUserChatProviderKey(ctx context.Context, providerID uuid.UUID, req CreateUserChatProviderKeyRequest) (UserChatProviderConfig, error) {
res, err := c.Request(ctx, http.MethodPut, fmt.Sprintf("/api/experimental/chats/user-provider-configs/%s", providerID), req)
if err != nil {
return UserChatProviderConfig{}, xerrors.Errorf("upsert user chat provider key: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return UserChatProviderConfig{}, ReadBodyAsError(res)
}
var config UserChatProviderConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
// DeleteUserChatProviderKey deletes a user API key for a provider.
func (c *ExperimentalClient) DeleteUserChatProviderKey(ctx context.Context, providerID uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/user-provider-configs/%s", providerID), nil)
if err != nil {
return xerrors.Errorf("delete user chat provider key: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// ListChatModelConfigs returns admin-managed chat model configs.
func (c *ExperimentalClient) ListChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/model-configs", nil)
+17 -13
View File
@@ -104,13 +104,14 @@ func seedChatDependencies(
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: safetyNet.URL,
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: safetyNet.URL,
CentralApiKeyEnabled: true,
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
@@ -186,12 +187,15 @@ func setOpenAIProviderBaseURL(
require.NoError(t, err)
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
CentralApiKeyEnabled: true,
AllowUserApiKey: false,
AllowCentralApiKeyFallback: false,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
})
require.NoError(t, err)
}
+65 -14
View File
@@ -73,12 +73,35 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err)
}
}
userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid)
if err != nil {
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
}
for _, userProviderKey := range userProviderKeys {
if strings.TrimSpace(userProviderKey.APIKey) == "" {
continue
}
if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
UserID: userProviderKey.UserID,
ChatProviderID: userProviderKey.ChatProviderID,
APIKey: userProviderKey.APIKey,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
}); err != nil {
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
}
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
return nil
}, &database.TxOptions{
Isolation: sql.LevelRepeatableRead,
})
if err != nil {
return xerrors.Errorf("update user links: %w", err)
return xerrors.Errorf("update user tokens and chat provider keys: %w", err)
}
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
@@ -97,12 +120,15 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
Enabled: provider.Enabled,
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
@@ -189,12 +215,32 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err)
}
}
userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid)
if err != nil {
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
}
for _, userProviderKey := range userProviderKeys {
if !userProviderKey.ApiKeyKeyID.Valid {
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
UserID: userProviderKey.UserID,
ChatProviderID: userProviderKey.ChatProviderID,
APIKey: userProviderKey.APIKey,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
}); err != nil {
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
}
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
}
return nil
}, &database.TxOptions{
Isolation: sql.LevelRepeatableRead,
})
if err != nil {
return xerrors.Errorf("update user links: %w", err)
return xerrors.Errorf("update user tokens and chat provider keys: %w", err)
}
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
@@ -209,12 +255,15 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
continue
}
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
Enabled: provider.Enabled,
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: provider.BaseUrl,
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
Enabled: provider.Enabled,
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
AllowUserApiKey: provider.AllowUserApiKey,
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
ID: provider.ID,
}); err != nil {
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
}
@@ -241,6 +290,8 @@ DELETE FROM user_links
DELETE FROM external_auth_links
WHERE oauth_access_token_key_id IS NOT NULL
OR oauth_refresh_token_key_id IS NOT NULL;
DELETE FROM user_chat_provider_keys
WHERE api_key_key_id IS NOT NULL;
UPDATE chat_providers
SET api_key = '',
api_key_key_id = NULL
+51
View File
@@ -471,6 +471,57 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat
return provider, nil
}
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
}
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.UserChatProviderKey{}, err
}
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := db.decryptUserChatProviderKey(&key); err != nil {
return database.UserChatProviderKey{}, err
}
return key, nil
}
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.UserChatProviderKey{}, err
}
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
if err != nil {
return database.UserChatProviderKey{}, err
}
if err := db.decryptUserChatProviderKey(&key); err != nil {
return database.UserChatProviderKey{}, err
}
return key, nil
}
// decryptMCPServerConfig decrypts all encrypted fields on a
// single MCPServerConfig in place.
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
+110
View File
@@ -1177,3 +1177,113 @@ func TestMCPServerUserTokens(t *testing.T) {
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
})
}
func TestUserChatProviderKeys(t *testing.T) {
t.Parallel()
ctx := context.Background()
const (
//nolint:gosec // test credentials
initialAPIKey = "sk-initial-api-key-value"
//nolint:gosec // test credentials
updatedAPIKey = "sk-updated-api-key-value"
)
insertProviderAndKey := func(
t *testing.T,
crypt *dbCrypt,
ciphers []Cipher,
) (database.ChatProvider, database.UserChatProviderKey) {
t.Helper()
user := dbgen.User(t, crypt, database.User{})
provider, err := crypt.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "",
Enabled: true,
AllowUserApiKey: true,
})
require.NoError(t, err)
key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: user.ID,
ChatProviderID: provider.ID,
APIKey: initialAPIKey,
})
require.NoError(t, err)
require.Equal(t, initialAPIKey, key.APIKey)
require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String)
return provider, key
}
getUserChatProviderKey := func(t *testing.T, store interface {
GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error)
}, userID uuid.UUID, providerID uuid.UUID,
) database.UserChatProviderKey {
t.Helper()
keys, err := store.GetUserChatProviderKeys(ctx, userID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, providerID, keys[0].ChatProviderID)
return keys[0]
}
t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
require.Equal(t, key.ID, got.ID)
require.Equal(t, initialAPIKey, got.APIKey)
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
require.NotEqual(t, initialAPIKey, rawKey.APIKey)
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey)
})
t.Run("GetUserChatProviderKeys", func(t *testing.T) {
t.Parallel()
_, crypt, ciphers := setup(t)
_, key := insertProviderAndKey(t, crypt, ciphers)
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, key.ID, keys[0].ID)
require.Equal(t, initialAPIKey, keys[0].APIKey)
require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String)
})
t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
provider, key := insertProviderAndKey(t, crypt, ciphers)
updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
UserID: key.UserID,
ChatProviderID: provider.ID,
APIKey: updatedAPIKey,
})
require.NoError(t, err)
require.Equal(t, key.ID, updated.ID)
require.Equal(t, key.CreatedAt, updated.CreatedAt)
require.False(t, updated.UpdatedAt.Before(key.UpdatedAt))
require.Equal(t, updatedAPIKey, updated.APIKey)
require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String)
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
require.Equal(t, updatedAPIKey, got.APIKey)
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, updatedAPIKey, keys[0].APIKey)
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
require.NotEqual(t, updatedAPIKey, rawKey.APIKey)
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
})
}
+34 -2
View File
@@ -1776,10 +1776,11 @@ export interface ChatModelProviderOptions {
// From codersdk/chats.go
export type ChatModelProviderUnavailableReason =
| "fetch_failed"
| "missing_api_key";
| "missing_api_key"
| "user_api_key_required";
export const ChatModelProviderUnavailableReasons: ChatModelProviderUnavailableReason[] =
["fetch_failed", "missing_api_key"];
["fetch_failed", "missing_api_key", "user_api_key_required"];
// From codersdk/chats.go
/**
@@ -1836,6 +1837,9 @@ export interface ChatProviderConfig {
readonly display_name: string;
readonly enabled: boolean;
readonly has_api_key: boolean;
readonly central_api_key_enabled: boolean;
readonly allow_user_api_key: boolean;
readonly allow_central_api_key_fallback: boolean;
readonly base_url?: string;
readonly source: ChatProviderConfigSource;
readonly created_at?: string;
@@ -2362,6 +2366,9 @@ export interface CreateChatProviderConfigRequest {
readonly api_key?: string;
readonly base_url?: string;
readonly enabled?: boolean;
readonly central_api_key_enabled?: boolean;
readonly allow_user_api_key?: boolean;
readonly allow_central_api_key_fallback?: boolean;
}
// From codersdk/chats.go
@@ -2643,6 +2650,15 @@ export interface CreateTokenRequest {
readonly allow_list?: readonly APIAllowListTarget[];
}
// From codersdk/chats.go
/**
* CreateUserChatProviderKeyRequest creates or replaces a user's API key
* for a provider.
*/
export interface CreateUserChatProviderKeyRequest {
readonly api_key: string;
}
// From codersdk/users.go
export interface CreateUserRequestWithOrgs {
readonly email: string;
@@ -7150,6 +7166,9 @@ export interface UpdateChatProviderConfigRequest {
readonly api_key?: string;
readonly base_url?: string;
readonly enabled?: boolean;
readonly central_api_key_enabled?: boolean;
readonly allow_user_api_key?: boolean;
readonly allow_central_api_key_fallback?: boolean;
}
// From codersdk/chats.go
@@ -7694,6 +7713,19 @@ export interface UserChatCustomPrompt {
readonly custom_prompt: string;
}
// From codersdk/chats.go
/**
* UserChatProviderConfig is a summary of a provider that allows
* user-supplied keys, as seen from the current user's perspective.
*/
export interface UserChatProviderConfig {
readonly provider_id: string;
readonly provider: string;
readonly display_name: string;
readonly has_user_api_key: boolean;
readonly has_central_api_key_fallback: boolean;
}
// From codersdk/insights.go
/**
* UserLatency shows the connection latency for a user.
@@ -20,6 +20,10 @@ const createProviderConfig = (
display_name: overrides.display_name ?? "",
enabled: overrides.enabled ?? true,
has_api_key: overrides.has_api_key ?? false,
central_api_key_enabled: overrides.central_api_key_enabled ?? true,
allow_user_api_key: overrides.allow_user_api_key ?? false,
allow_central_api_key_fallback:
overrides.allow_central_api_key_fallback ?? false,
base_url: overrides.base_url ?? "",
source: overrides.source ?? "database",
created_at: overrides.created_at ?? now,
@@ -14,6 +14,9 @@ const providerState: ProviderState = {
display_name: "OpenAI",
enabled: true,
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
base_url: undefined,
source: "database",
created_at: "2025-01-01T00:00:00Z",