mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: provider key policies and user provider settings (#23751)
This commit is contained in:
@@ -110,6 +110,9 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
- For experimental or unstable API paths, skip public doc generation with
|
||||
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
|
||||
keeps them out of the published API reference until they stabilize.
|
||||
- Experimental chat endpoints in `coderd/exp_chats.go` omit swagger
|
||||
annotations entirely. Do not add `@Summary`, `@Router`, or other
|
||||
swagger comments to handlers in that file.
|
||||
|
||||
### Database Query Naming
|
||||
|
||||
|
||||
+8
-1
@@ -782,7 +782,7 @@ func New(options *Options) *API {
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
@@ -1221,6 +1221,13 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
|
||||
})
|
||||
})
|
||||
r.Route("/user-provider-configs", func(r chi.Router) {
|
||||
r.Get("/", api.listUserChatProviderConfigs)
|
||||
r.Route("/{providerConfig}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertUserChatProviderKey)
|
||||
r.Delete("/", api.deleteUserChatProviderKey)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
@@ -32,4 +33,5 @@ const (
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys
|
||||
)
|
||||
|
||||
@@ -2137,6 +2137,17 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat
|
||||
return q.db.DeleteUserChatCompactionThreshold(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
@@ -4024,6 +4035,17 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetUserChatProviderKeys(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
@@ -6454,6 +6476,17 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U
|
||||
return q.db.UpdateUserChatCustomPrompt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpdateUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
return deleteQ(q.log, q.auth, q.db.GetUserByID, q.db.UpdateUserDeletedByID)(ctx, id)
|
||||
}
|
||||
@@ -7181,6 +7214,17 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return q.db.UpsertTemplateUsageStats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return q.db.UpsertUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -2407,6 +2407,36 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt")
|
||||
}))
|
||||
s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key})
|
||||
}))
|
||||
s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns()
|
||||
}))
|
||||
s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"}
|
||||
key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
|
||||
|
||||
@@ -696,6 +696,14 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
@@ -2528,6 +2536,14 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
@@ -4560,6 +4576,14 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateUserDeletedByID(ctx, id)
|
||||
@@ -5152,6 +5176,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg)
|
||||
|
||||
@@ -1171,6 +1171,20 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4729,6 +4743,21 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys mocks base method.
|
||||
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8605,6 +8634,21 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserDeletedByID mocks base method.
|
||||
func (m *MockStore) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9671,6 +9715,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserChatProviderKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey.
|
||||
func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertWebpushVAPIDKeys mocks base method.
|
||||
func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+31
-1
@@ -1341,7 +1341,11 @@ CREATE TABLE chat_providers (
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
base_url text DEFAULT ''::text NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text])))
|
||||
central_api_key_enabled boolean DEFAULT true NOT NULL,
|
||||
allow_user_api_key boolean DEFAULT false NOT NULL,
|
||||
allow_central_api_key_fallback boolean DEFAULT false NOT NULL,
|
||||
CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))),
|
||||
CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))))
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted';
|
||||
@@ -2752,6 +2756,17 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of
|
||||
|
||||
COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.';
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
chat_provider_id uuid NOT NULL,
|
||||
api_key text NOT NULL,
|
||||
api_key_key_id text,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text))
|
||||
);
|
||||
|
||||
CREATE TABLE user_configs (
|
||||
user_id uuid NOT NULL,
|
||||
key character varying(256) NOT NULL,
|
||||
@@ -3548,6 +3563,12 @@ ALTER TABLE ONLY usage_events_daily
|
||||
ALTER TABLE ONLY usage_events
|
||||
ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
|
||||
@@ -4258,6 +4279,15 @@ ALTER TABLE ONLY templates
|
||||
ALTER TABLE ONLY templates
|
||||
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_chat_provider_keys
|
||||
ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_configs
|
||||
ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ const (
|
||||
ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE;
|
||||
ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT;
|
||||
ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
DROP TABLE IF EXISTS user_chat_provider_keys;
|
||||
|
||||
ALTER TABLE chat_providers DROP CONSTRAINT IF EXISTS valid_credential_policy;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
DROP COLUMN IF EXISTS central_api_key_enabled,
|
||||
DROP COLUMN IF EXISTS allow_user_api_key,
|
||||
DROP COLUMN IF EXISTS allow_central_api_key_fallback;
|
||||
@@ -0,0 +1,24 @@
|
||||
ALTER TABLE chat_providers
|
||||
ADD COLUMN central_api_key_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
ADD COLUMN allow_user_api_key BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
ADD COLUMN allow_central_api_key_fallback BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
ALTER TABLE chat_providers
|
||||
ADD CONSTRAINT valid_credential_policy CHECK (
|
||||
(central_api_key_enabled OR allow_user_api_key) AND
|
||||
(
|
||||
NOT allow_central_api_key_fallback OR
|
||||
(central_api_key_enabled AND allow_user_api_key)
|
||||
)
|
||||
);
|
||||
|
||||
CREATE TABLE user_chat_provider_keys (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
chat_provider_id UUID NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE,
|
||||
api_key TEXT NOT NULL CHECK (api_key != ''),
|
||||
api_key_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (user_id, chat_provider_id)
|
||||
);
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
INSERT INTO user_chat_provider_keys (
|
||||
user_id,
|
||||
chat_provider_id,
|
||||
api_key,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
'0a8b2f84-b5a8-4c44-8c9f-e58c44a534a7',
|
||||
'fixture-test-key',
|
||||
'2025-01-01 00:00:00+00',
|
||||
'2025-01-01 00:00:00+00'
|
||||
FROM users
|
||||
ORDER BY created_at, id
|
||||
LIMIT 1;
|
||||
@@ -4264,12 +4264,15 @@ type ChatProvider struct {
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
// The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
type ChatQueuedMessage struct {
|
||||
@@ -5222,6 +5225,16 @@ type User struct {
|
||||
ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"`
|
||||
}
|
||||
|
||||
type UserChatProviderKey struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Key string `db:"key" json:"key"`
|
||||
|
||||
@@ -150,6 +150,7 @@ type sqlcQuerier interface {
|
||||
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
@@ -577,6 +578,7 @@ type sqlcQuerier interface {
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
// Returns the minimum (most restrictive) group limit for a user.
|
||||
@@ -927,6 +929,7 @@ type sqlcQuerier interface {
|
||||
UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error
|
||||
UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error)
|
||||
UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error)
|
||||
UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error
|
||||
UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error
|
||||
UpdateUserHashedOneTimePasscode(ctx context.Context, arg UpdateUserHashedOneTimePasscodeParams) error
|
||||
@@ -1015,6 +1018,7 @@ type sqlcQuerier interface {
|
||||
// used to store the data, and the minutes are summed for each user and template
|
||||
// combination. The result is stored in the template_usage_stats table.
|
||||
UpsertTemplateUsageStats(ctx context.Context) error
|
||||
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
|
||||
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
|
||||
UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error)
|
||||
|
||||
@@ -1261,10 +1261,11 @@ func TestGetAuthorizedChats(t *testing.T) {
|
||||
// Create FK dependencies: a chat provider and model config.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -9456,10 +9457,11 @@ func TestInsertChatMessages(t *testing.T) {
|
||||
provider := "openai"
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: provider,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -9621,10 +9623,11 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
|
||||
// A chat_providers row is required as a FK for model configs.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -9992,10 +9995,11 @@ func TestGetPRInsights(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10516,10 +10520,11 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
// timed test context doesn't tick during DB init.
|
||||
bg := context.Background()
|
||||
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10696,10 +10701,11 @@ func TestChatLabels(t *testing.T) {
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10907,10 +10913,11 @@ func TestChatHasUnread(t *testing.T) {
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
+181
-22
@@ -3799,7 +3799,7 @@ func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) e
|
||||
|
||||
const getChatProviderByID = `-- name: GetChatProviderByID :one
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
@@ -3820,13 +3820,16 @@ func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (Cha
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
@@ -3847,13 +3850,16 @@ func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider str
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatProviders = `-- name: GetChatProviders :many
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
ORDER BY
|
||||
@@ -3880,6 +3886,9 @@ func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, erro
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3896,7 +3905,7 @@ func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, erro
|
||||
|
||||
const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many
|
||||
SELECT
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
FROM
|
||||
chat_providers
|
||||
WHERE
|
||||
@@ -3925,6 +3934,9 @@ func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvide
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3947,7 +3959,10 @@ INSERT INTO chat_providers (
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled
|
||||
enabled,
|
||||
central_api_key_enabled,
|
||||
allow_user_api_key,
|
||||
allow_central_api_key_fallback
|
||||
) VALUES (
|
||||
$1::text,
|
||||
$2::text,
|
||||
@@ -3955,20 +3970,26 @@ INSERT INTO chat_providers (
|
||||
$4::text,
|
||||
$5::text,
|
||||
$6::uuid,
|
||||
$7::boolean
|
||||
$7::boolean,
|
||||
$8::boolean,
|
||||
$9::boolean,
|
||||
$10::boolean
|
||||
)
|
||||
RETURNING
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
`
|
||||
|
||||
type InsertChatProviderParams struct {
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) {
|
||||
@@ -3980,6 +4001,9 @@ func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProvi
|
||||
arg.ApiKeyKeyID,
|
||||
arg.CreatedBy,
|
||||
arg.Enabled,
|
||||
arg.CentralApiKeyEnabled,
|
||||
arg.AllowUserApiKey,
|
||||
arg.AllowCentralApiKeyFallback,
|
||||
)
|
||||
var i ChatProvider
|
||||
err := row.Scan(
|
||||
@@ -3993,6 +4017,9 @@ func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProvi
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4006,20 +4033,26 @@ SET
|
||||
base_url = $3::text,
|
||||
api_key_key_id = $4::text,
|
||||
enabled = $5::boolean,
|
||||
central_api_key_enabled = $6::boolean,
|
||||
allow_user_api_key = $7::boolean,
|
||||
allow_central_api_key_fallback = $8::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $6::uuid
|
||||
id = $9::uuid
|
||||
RETURNING
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url
|
||||
id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback
|
||||
`
|
||||
|
||||
type UpdateChatProviderParams struct {
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
BaseUrl string `db:"base_url" json:"base_url"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
Enabled bool `db:"enabled" json:"enabled"`
|
||||
CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"`
|
||||
AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"`
|
||||
AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) {
|
||||
@@ -4029,6 +4062,9 @@ func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProvi
|
||||
arg.BaseUrl,
|
||||
arg.ApiKeyKeyID,
|
||||
arg.Enabled,
|
||||
arg.CentralApiKeyEnabled,
|
||||
arg.AllowUserApiKey,
|
||||
arg.AllowCentralApiKeyFallback,
|
||||
arg.ID,
|
||||
)
|
||||
var i ChatProvider
|
||||
@@ -4043,6 +4079,9 @@ func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProvi
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.BaseUrl,
|
||||
&i.CentralApiKeyEnabled,
|
||||
&i.AllowUserApiKey,
|
||||
&i.AllowCentralApiKeyFallback,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -22617,6 +22656,126 @@ func (q *sqlQuerier) UpdateUserSecret(ctx context.Context, arg UpdateUserSecretP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec
|
||||
DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2
|
||||
`
|
||||
|
||||
type DeleteUserChatProviderKeyParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many
|
||||
SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserChatProviderKey
|
||||
for rows.Next() {
|
||||
var i UserChatProviderKey
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChatProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one
|
||||
UPDATE user_chat_provider_keys
|
||||
SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW()
|
||||
WHERE user_id = $3 AND chat_provider_id = $4
|
||||
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateUserChatProviderKeyParams struct {
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserChatProviderKey,
|
||||
arg.APIKey,
|
||||
arg.ApiKeyKeyID,
|
||||
arg.UserID,
|
||||
arg.ChatProviderID,
|
||||
)
|
||||
var i UserChatProviderKey
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChatProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one
|
||||
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
|
||||
VALUES ($1, $2, $3, $4::text)
|
||||
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
|
||||
api_key = $3,
|
||||
api_key_key_id = $4::text,
|
||||
updated_at = NOW()
|
||||
RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpsertUserChatProviderKeyParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"`
|
||||
APIKey string `db:"api_key" json:"api_key"`
|
||||
ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) {
|
||||
row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey,
|
||||
arg.UserID,
|
||||
arg.ChatProviderID,
|
||||
arg.APIKey,
|
||||
arg.ApiKeyKeyID,
|
||||
)
|
||||
var i UserChatProviderKey
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.ChatProviderID,
|
||||
&i.APIKey,
|
||||
&i.ApiKeyKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const allUserIDs = `-- name: AllUserIDs :many
|
||||
SELECT DISTINCT id FROM USERS
|
||||
WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END
|
||||
|
||||
@@ -40,7 +40,10 @@ INSERT INTO chat_providers (
|
||||
base_url,
|
||||
api_key_key_id,
|
||||
created_by,
|
||||
enabled
|
||||
enabled,
|
||||
central_api_key_enabled,
|
||||
allow_user_api_key,
|
||||
allow_central_api_key_fallback
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@display_name::text,
|
||||
@@ -48,7 +51,10 @@ INSERT INTO chat_providers (
|
||||
@base_url::text,
|
||||
sqlc.narg('api_key_key_id')::text,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@enabled::boolean
|
||||
@enabled::boolean,
|
||||
@central_api_key_enabled::boolean,
|
||||
@allow_user_api_key::boolean,
|
||||
@allow_central_api_key_fallback::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -62,6 +68,9 @@ SET
|
||||
base_url = @base_url::text,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
enabled = @enabled::boolean,
|
||||
central_api_key_enabled = @central_api_key_enabled::boolean,
|
||||
allow_user_api_key = @allow_user_api_key::boolean,
|
||||
allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
-- name: GetUserChatProviderKeys :many
|
||||
SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC;
|
||||
|
||||
-- name: UpsertUserChatProviderKey :one
|
||||
INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id)
|
||||
VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text)
|
||||
ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET
|
||||
api_key = @api_key,
|
||||
api_key_key_id = sqlc.narg('api_key_key_id')::text,
|
||||
updated_at = NOW()
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserChatProviderKey :one
|
||||
UPDATE user_chat_provider_keys
|
||||
SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW()
|
||||
WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserChatProviderKey :exec
|
||||
DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id;
|
||||
@@ -90,6 +90,8 @@ const (
|
||||
UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
|
||||
UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type);
|
||||
UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id);
|
||||
UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id);
|
||||
UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key);
|
||||
UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id);
|
||||
UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
|
||||
|
||||
+447
-57
@@ -520,6 +520,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
//nolint:gocritic // System context required to read enabled chat models.
|
||||
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
@@ -555,14 +556,24 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
configuredProviders := make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
||||
)
|
||||
enabledProviderNames := make(map[string]struct{}, len(enabledProviders))
|
||||
for _, provider := range enabledProviders {
|
||||
configuredProviders = append(
|
||||
configuredProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
},
|
||||
)
|
||||
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
enabledProviderNames[normalizedProvider] = struct{}{}
|
||||
}
|
||||
configuredModels := make(
|
||||
[]chatprovider.ConfiguredModel, 0, len(enabledModels),
|
||||
@@ -575,18 +586,38 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
keys := chatprovider.MergeProviderAPIKeys(
|
||||
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
userKeyRows, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to load user chat provider keys.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
userKeys := make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
||||
for _, userKey := range userKeyRows {
|
||||
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
||||
ChatProviderID: userKey.ChatProviderID,
|
||||
APIKey: userKey.APIKey,
|
||||
})
|
||||
}
|
||||
|
||||
_, providerAvailability := chatprovider.ResolveUserProviderKeys(
|
||||
ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
configuredProviders,
|
||||
userKeys,
|
||||
)
|
||||
catalog := chatprovider.NewModelCatalog(keys)
|
||||
catalog := chatprovider.NewModelCatalog()
|
||||
var response codersdk.ChatModelsResponse
|
||||
if configured, ok := catalog.ListConfiguredModels(
|
||||
configuredProviders, configuredModels,
|
||||
configuredProviders, configuredModels, providerAvailability, enabledProviderNames,
|
||||
); ok {
|
||||
response = configured
|
||||
} else {
|
||||
response = catalog.ListConfiguredProviderAvailability(configuredProviders)
|
||||
response = catalog.ListConfiguredProviderAvailability(
|
||||
providerAvailability,
|
||||
enabledProviderNames,
|
||||
)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
@@ -3926,9 +3957,13 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
||||
)
|
||||
for _, provider := range enabledProviders {
|
||||
normalizedProvider := normalizeChatProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
enabledConfiguredProviders = append(
|
||||
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
Provider: normalizedProvider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
},
|
||||
@@ -3936,7 +3971,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
||||
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
enabledConfiguredProviders,
|
||||
)
|
||||
effectiveKeys = chatprovider.MergeProviderAPIKeys(
|
||||
@@ -3952,7 +3987,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
resp,
|
||||
convertChatProviderConfig(
|
||||
configured,
|
||||
effectiveKeys.APIKey(provider) != "",
|
||||
api.hasEffectiveProviderAPIKey(ctx, configured),
|
||||
codersdk.ChatProviderConfigSourceDatabase,
|
||||
),
|
||||
)
|
||||
@@ -3968,13 +4003,16 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
resp = append(resp, codersdk.ChatProviderConfig{
|
||||
ID: uuid.Nil,
|
||||
Provider: provider,
|
||||
DisplayName: chatprovider.ProviderDisplayName(provider),
|
||||
Enabled: enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
BaseURL: effectiveKeys.BaseURL(provider),
|
||||
Source: source,
|
||||
ID: uuid.Nil,
|
||||
Provider: provider,
|
||||
DisplayName: chatprovider.ProviderDisplayName(provider),
|
||||
Enabled: enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
CentralAPIKeyEnabled: true,
|
||||
AllowUserAPIKey: false,
|
||||
AllowCentralAPIKeyFallback: false,
|
||||
BaseURL: effectiveKeys.BaseURL(provider),
|
||||
Source: source,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3984,6 +4022,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
var inserted database.ChatProvider
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
@@ -4003,6 +4042,14 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateChatProviderAPIKeySize(strings.TrimSpace(req.APIKey)); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
@@ -4016,14 +4063,57 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
inserted, err := api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
APIKey: strings.TrimSpace(req.APIKey),
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
Enabled: enabled,
|
||||
centralAPIKeyEnabled := true
|
||||
if req.CentralAPIKeyEnabled != nil {
|
||||
centralAPIKeyEnabled = *req.CentralAPIKeyEnabled
|
||||
}
|
||||
allowUserAPIKey := false
|
||||
if req.AllowUserAPIKey != nil {
|
||||
allowUserAPIKey = *req.AllowUserAPIKey
|
||||
}
|
||||
allowCentralAPIKeyFallback := false
|
||||
if req.AllowCentralAPIKeyFallback != nil {
|
||||
allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback
|
||||
}
|
||||
|
||||
if err := validateChatProviderCredentialPolicy(
|
||||
centralAPIKeyEnabled,
|
||||
allowUserAPIKey,
|
||||
allowCentralAPIKeyFallback,
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid credential policy.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateChatProviderCentralAPIKey(
|
||||
centralAPIKeyEnabled,
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{
|
||||
Provider: provider,
|
||||
APIKey: strings.TrimSpace(req.APIKey),
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
}, uuid.Nil),
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
inserted, err = api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: strings.TrimSpace(req.DisplayName),
|
||||
APIKey: strings.TrimSpace(req.APIKey),
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
Enabled: enabled,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
})
|
||||
if err != nil {
|
||||
switch {
|
||||
@@ -4064,6 +4154,10 @@ func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
var (
|
||||
existing database.ChatProvider
|
||||
updated database.ChatProvider
|
||||
)
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
@@ -4105,7 +4199,17 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := existing.APIKey
|
||||
apiKeyKeyID := existing.ApiKeyKeyID
|
||||
if req.APIKey != nil {
|
||||
apiKey = strings.TrimSpace(*req.APIKey)
|
||||
trimmedAPIKey := strings.TrimSpace(*req.APIKey)
|
||||
if trimmedAPIKey != "" {
|
||||
if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
apiKey = trimmedAPIKey
|
||||
apiKeyKeyID = sql.NullString{}
|
||||
}
|
||||
baseURL := existing.BaseUrl
|
||||
@@ -4120,13 +4224,57 @@ func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
updated, err := api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: displayName,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: apiKeyKeyID,
|
||||
Enabled: enabled,
|
||||
ID: existing.ID,
|
||||
centralAPIKeyEnabled := existing.CentralApiKeyEnabled
|
||||
if req.CentralAPIKeyEnabled != nil {
|
||||
centralAPIKeyEnabled = *req.CentralAPIKeyEnabled
|
||||
}
|
||||
allowUserAPIKey := existing.AllowUserApiKey
|
||||
if req.AllowUserAPIKey != nil {
|
||||
allowUserAPIKey = *req.AllowUserAPIKey
|
||||
}
|
||||
allowCentralAPIKeyFallback := existing.AllowCentralApiKeyFallback
|
||||
if req.AllowCentralAPIKeyFallback != nil {
|
||||
allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback
|
||||
}
|
||||
|
||||
if err := validateChatProviderCredentialPolicy(
|
||||
centralAPIKeyEnabled,
|
||||
allowUserAPIKey,
|
||||
allowCentralAPIKeyFallback,
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid credential policy.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateChatProviderCentralAPIKey(
|
||||
centralAPIKeyEnabled,
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{
|
||||
ID: existing.ID,
|
||||
Provider: existing.Provider,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
}, existing.ID),
|
||||
); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
updated, err = api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: displayName,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: apiKeyKeyID,
|
||||
Enabled: enabled,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
ID: existing.ID,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -4197,6 +4345,169 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
//nolint:gocritic // Non-admin users need to read provider configs to manage their own chat credentials.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
providers, err := api.Database.GetChatProviders(chatdCtx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chat providers.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userKeys, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list user chat provider keys.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
hasUserAPIKeyByProviderID := make(map[uuid.UUID]bool, len(userKeys))
|
||||
for _, userKey := range userKeys {
|
||||
hasUserAPIKeyByProviderID[userKey.ChatProviderID] = true
|
||||
}
|
||||
|
||||
resp := make([]codersdk.UserChatProviderConfig, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
if !provider.Enabled || !provider.AllowUserApiKey {
|
||||
continue
|
||||
}
|
||||
hasUserAPIKey := hasUserAPIKeyByProviderID[provider.ID]
|
||||
hasCentralAPIKeyFallback := provider.Enabled &&
|
||||
provider.AllowCentralApiKeyFallback &&
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
|
||||
resp = append(
|
||||
resp,
|
||||
convertUserChatProviderConfig(
|
||||
provider,
|
||||
hasUserAPIKey,
|
||||
hasCentralAPIKeyFallback,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
providerID, ok := parseChatProviderID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // Non-admin users need to validate provider availability before storing their own key.
|
||||
provider, err := api.Database.GetChatProviderByID(dbauthz.AsChatd(ctx), providerID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat provider.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !provider.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Provider is disabled.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !provider.AllowUserApiKey {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Provider does not allow user API keys.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.CreateUserChatProviderKeyRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
trimmedAPIKey := strings.TrimSpace(req.APIKey)
|
||||
if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if trimmedAPIKey == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := api.Database.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: apiKey.UserID,
|
||||
ChatProviderID: providerID,
|
||||
APIKey: trimmedAPIKey,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to save user chat provider key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
hasCentralAPIKeyFallback := provider.Enabled &&
|
||||
provider.AllowCentralApiKeyFallback &&
|
||||
api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
|
||||
httpapi.Write(
|
||||
ctx,
|
||||
rw,
|
||||
http.StatusOK,
|
||||
convertUserChatProviderConfig(
|
||||
provider,
|
||||
true,
|
||||
hasCentralAPIKeyFallback,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (api *API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
providerID, ok := parseChatProviderID(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Database.DeleteUserChatProviderKey(ctx, database.DeleteUserChatProviderKeyParams{
|
||||
UserID: apiKey.UserID,
|
||||
ChatProviderID: providerID,
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete user chat provider key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
@@ -4737,15 +5048,37 @@ func convertChatProviderConfig(
|
||||
}
|
||||
|
||||
return codersdk.ChatProviderConfig{
|
||||
ID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
DisplayName: displayName,
|
||||
Enabled: provider.Enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
BaseURL: strings.TrimSpace(provider.BaseUrl),
|
||||
Source: source,
|
||||
CreatedAt: provider.CreatedAt,
|
||||
UpdatedAt: provider.UpdatedAt,
|
||||
ID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
DisplayName: displayName,
|
||||
Enabled: provider.Enabled,
|
||||
HasAPIKey: hasAPIKey,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
BaseURL: strings.TrimSpace(provider.BaseUrl),
|
||||
Source: source,
|
||||
CreatedAt: provider.CreatedAt,
|
||||
UpdatedAt: provider.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func convertUserChatProviderConfig(
|
||||
provider database.ChatProvider,
|
||||
hasUserAPIKey bool,
|
||||
hasCentralAPIKeyFallback bool,
|
||||
) codersdk.UserChatProviderConfig {
|
||||
displayName := strings.TrimSpace(provider.DisplayName)
|
||||
if displayName == "" {
|
||||
displayName = chatprovider.ProviderDisplayName(provider.Provider)
|
||||
}
|
||||
|
||||
return codersdk.UserChatProviderConfig{
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
DisplayName: displayName,
|
||||
HasUserAPIKey: hasUserAPIKey,
|
||||
HasCentralAPIKeyFallback: hasCentralAPIKeyFallback,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4904,26 +5237,80 @@ func chatProviderValidationDetail() string {
|
||||
return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "."
|
||||
}
|
||||
|
||||
func chatProviderAPIKeysFromDeploymentValues(
|
||||
deploymentValues *codersdk.DeploymentValues,
|
||||
) chatprovider.ProviderAPIKeys {
|
||||
_ = deploymentValues
|
||||
// For now, we'll just manage configs in the UI.
|
||||
// We should probably not be reusing the AI bridge configs anyways.
|
||||
return chatprovider.ProviderAPIKeys{
|
||||
// OpenAI: deploymentValues.AI.BridgeConfig.OpenAI.Key.Value(),
|
||||
// Anthropic: deploymentValues.AI.BridgeConfig.Anthropic.Key.Value(),
|
||||
// BaseURLByProvider: map[string]string{
|
||||
// "openai": deploymentValues.AI.BridgeConfig.OpenAI.BaseURL.Value(),
|
||||
// "anthropic": deploymentValues.AI.BridgeConfig.Anthropic.BaseURL.Value(),
|
||||
// },
|
||||
const maxChatProviderAPIKeySize = 10240 // 10 KB
|
||||
|
||||
func validateChatProviderAPIKeySize(apiKey string) error {
|
||||
if len(apiKey) > maxChatProviderAPIKeySize {
|
||||
return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:revive // This helper validates the explicit credential policy tuple.
|
||||
func validateChatProviderCredentialPolicy(
|
||||
centralEnabled, allowUserKey, allowFallback bool,
|
||||
) error {
|
||||
if !centralEnabled && !allowUserKey {
|
||||
return xerrors.New(
|
||||
"At least one credential source must be enabled: central API key or user API key.",
|
||||
)
|
||||
}
|
||||
if allowFallback && !centralEnabled {
|
||||
return xerrors.New(
|
||||
"Central API key fallback requires central API key to be enabled.",
|
||||
)
|
||||
}
|
||||
if allowFallback && !allowUserKey {
|
||||
return xerrors.New(
|
||||
"Central API key fallback requires user API key to be enabled.",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:revive // This helper validates central-key requirements.
|
||||
func validateChatProviderCentralAPIKey(
|
||||
centralEnabled bool,
|
||||
hasCentralAPIKey bool,
|
||||
) error {
|
||||
if centralEnabled && !hasCentralAPIKey {
|
||||
return xerrors.New(
|
||||
"API key is required when central API key is enabled.",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat
|
||||
// provider API keys.
|
||||
func ChatProviderAPIKeysFromDeploymentValues(
|
||||
_ *codersdk.DeploymentValues,
|
||||
) chatprovider.ProviderAPIKeys {
|
||||
// AI bridge deployment config is intentionally not reused for chat
|
||||
// provider credentials. Bridge keys serve the AI task subsystem and
|
||||
// should not silently broaden into chat execution paths.
|
||||
return chatprovider.ProviderAPIKeys{}
|
||||
}
|
||||
|
||||
func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool {
|
||||
return api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil)
|
||||
}
|
||||
|
||||
func (api *API) hasEffectiveCentralProviderAPIKey(
|
||||
ctx context.Context,
|
||||
provider database.ChatProvider,
|
||||
excludeProviderID uuid.UUID,
|
||||
) bool {
|
||||
if !provider.CentralApiKeyEnabled {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(provider.APIKey) != "" {
|
||||
return true
|
||||
}
|
||||
deploymentKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues)
|
||||
if deploymentKeys.APIKey(provider.Provider) != "" {
|
||||
return true
|
||||
}
|
||||
if api.chatDaemon == nil {
|
||||
return false
|
||||
}
|
||||
@@ -4945,6 +5332,9 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
|
||||
[]chatprovider.ConfiguredProvider, 0, len(enabledProviders),
|
||||
)
|
||||
for _, configured := range enabledProviders {
|
||||
if excludeProviderID != uuid.Nil && configured.ID == excludeProviderID {
|
||||
continue
|
||||
}
|
||||
enabledConfiguredProviders = append(
|
||||
enabledConfiguredProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: configured.Provider,
|
||||
@@ -4955,7 +5345,7 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
|
||||
}
|
||||
|
||||
effectiveKeys := chatprovider.MergeProviderAPIKeys(
|
||||
chatProviderAPIKeysFromDeploymentValues(api.DeploymentValues),
|
||||
deploymentKeys,
|
||||
enabledConfiguredProviders,
|
||||
)
|
||||
return effectiveKeys.APIKey(provider.Provider) != ""
|
||||
|
||||
+994
-4
File diff suppressed because it is too large
Load Diff
@@ -39,13 +39,14 @@ func TestChatParam(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.InsertChatProvider(context.Background(), database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
BaseUrl: "https://api.openai.com/v1",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
+65
-19
@@ -1585,17 +1585,17 @@ func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) e
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
|
||||
}
|
||||
if isFreshManualTitleLock(lockedChat, now) {
|
||||
// Only a fresh manual lock or a chat without a real worker should
|
||||
// block title regeneration. Running chats with a real worker may
|
||||
// regenerate their title concurrently, and last write wins.
|
||||
hasRealWorker := lockedChat.Status == database.ChatStatusRunning &&
|
||||
lockedChat.WorkerID.Valid &&
|
||||
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
|
||||
if lockedChat.Status == database.ChatStatusPending ||
|
||||
(lockedChat.Status == database.ChatStatusRunning && !hasRealWorker) ||
|
||||
isFreshManualTitleLock(lockedChat, now) {
|
||||
return ErrManualTitleRegenerationInProgress
|
||||
}
|
||||
|
||||
// Only write the lock marker when no real worker owns WorkerID.
|
||||
// When a real worker is running, we skip the DB lock but still
|
||||
// allow regeneration. The frontend prevents same-browser
|
||||
// double-clicks, and concurrent regeneration from different
|
||||
// replicas is harmless, last write wins.
|
||||
hasRealWorker := lockedChat.WorkerID.Valid &&
|
||||
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
|
||||
if hasRealWorker {
|
||||
return nil
|
||||
}
|
||||
@@ -1658,7 +1658,7 @@ func (p *Server) RegenerateChatTitle(
|
||||
// keeping chat ownership authorization at the HTTP layer.
|
||||
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
keys, err := p.resolveProviderAPIKeys(chatdCtx)
|
||||
keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err)
|
||||
}
|
||||
@@ -4808,7 +4808,7 @@ func (p *Server) resolveChatModel(
|
||||
})
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
keys, err = p.resolveProviderAPIKeys(ctx)
|
||||
keys, err = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
@@ -4830,8 +4830,9 @@ func (p *Server) resolveChatModel(
|
||||
return model, dbConfig, keys, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveProviderAPIKeys(
|
||||
func (p *Server) resolveUserProviderAPIKeys(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
) (chatprovider.ProviderAPIKeys, error) {
|
||||
providers, err := p.configCache.EnabledProviders(ctx)
|
||||
if err != nil {
|
||||
@@ -4840,17 +4841,62 @@ func (p *Server) resolveProviderAPIKeys(
|
||||
err,
|
||||
)
|
||||
}
|
||||
dbProviders := make(
|
||||
configuredProviders := make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(providers),
|
||||
)
|
||||
for _, provider := range providers {
|
||||
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
})
|
||||
configuredProviders = append(
|
||||
configuredProviders, chatprovider.ConfiguredProvider{
|
||||
ProviderID: provider.ID,
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
CentralAPIKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserAPIKey: provider.AllowUserApiKey,
|
||||
AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
},
|
||||
)
|
||||
}
|
||||
return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil
|
||||
allowAnyUserAPIKey := false
|
||||
for _, provider := range configuredProviders {
|
||||
if provider.AllowUserAPIKey {
|
||||
allowAnyUserAPIKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
userKeys := []chatprovider.UserProviderKey{}
|
||||
if allowAnyUserAPIKey {
|
||||
userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"get user chat provider keys: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
||||
for _, userKey := range userKeyRows {
|
||||
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
||||
ChatProviderID: userKey.ChatProviderID,
|
||||
APIKey: userKey.APIKey,
|
||||
})
|
||||
}
|
||||
}
|
||||
keys, _ := chatprovider.ResolveUserProviderKeys(
|
||||
p.providerAPIKeys,
|
||||
configuredProviders,
|
||||
userKeys,
|
||||
)
|
||||
enabledProviders := make(map[string]struct{}, len(configuredProviders))
|
||||
for _, provider := range configuredProviders {
|
||||
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
enabledProviders[normalizedProvider] = struct{}{}
|
||||
}
|
||||
chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// resolveModelConfig looks up the chat's model config by its
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -99,9 +100,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
@@ -261,9 +263,10 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
@@ -378,6 +381,87 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ownerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
configCache: newChatConfigCache(
|
||||
context.Background(),
|
||||
db,
|
||||
quartz.NewReal(),
|
||||
),
|
||||
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "openai-deployment-key",
|
||||
Anthropic: "anthropic-deployment-key",
|
||||
ByProvider: map[string]string{
|
||||
"openai": "openai-deployment-key",
|
||||
"anthropic": "anthropic-deployment-key",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
"openai": "https://openai.example.com",
|
||||
"anthropic": "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "anthropic",
|
||||
CentralApiKeyEnabled: true,
|
||||
AllowCentralApiKeyFallback: true,
|
||||
}}, nil)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys.OpenAI)
|
||||
require.Empty(t, keys.APIKey("openai"))
|
||||
require.Empty(t, keys.BaseURL("openai"))
|
||||
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
|
||||
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
|
||||
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
|
||||
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
|
||||
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
|
||||
}
|
||||
|
||||
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ownerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
configCache: newChatConfigCache(
|
||||
context.Background(),
|
||||
db,
|
||||
quartz.NewReal(),
|
||||
),
|
||||
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "openai-deployment-key",
|
||||
ByProvider: map[string]string{
|
||||
"openai": "openai-deployment-key",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "openai",
|
||||
CentralApiKeyEnabled: true,
|
||||
}}, nil)
|
||||
|
||||
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "openai-deployment-key", keys.OpenAI)
|
||||
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
|
||||
}
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -523,7 +607,8 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).AnyTimes()
|
||||
conn.EXPECT().ReadFile(gomock.Any(),
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1)).Return(
|
||||
|
||||
+248
-18
@@ -2893,12 +2893,13 @@ func seedChatDependenciesWithProvider(
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: baseURL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: "test-key",
|
||||
BaseUrl: baseURL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
@@ -2917,6 +2918,102 @@ func seedChatDependenciesWithProvider(
|
||||
return user, model
|
||||
}
|
||||
|
||||
func seedChatDependenciesWithProviderPolicy(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
provider string,
|
||||
baseURL string,
|
||||
apiKey string,
|
||||
centralAPIKeyEnabled bool,
|
||||
allowUserAPIKey bool,
|
||||
allowCentralAPIKeyFallback bool,
|
||||
) (database.User, database.ChatProvider, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: apiKey,
|
||||
BaseUrl: baseURL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: provider,
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return user, providerConfig, model
|
||||
}
|
||||
|
||||
func waitForTerminalChatStatusEvent(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
events <-chan codersdk.ChatStreamEvent,
|
||||
) codersdk.ChatStatus {
|
||||
t.Helper()
|
||||
|
||||
var terminalStatus codersdk.ChatStatus
|
||||
testutil.Eventually(ctx, t, func(context.Context) bool {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
continue
|
||||
}
|
||||
if event.Status.Status == codersdk.ChatStatusWaiting || event.Status.Status == codersdk.ChatStatusError {
|
||||
terminalStatus = event.Status.Status
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
return terminalStatus
|
||||
}
|
||||
|
||||
func waitForTerminalChat(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
var chatResult database.Chat
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
got, err := db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
return chatResult
|
||||
}
|
||||
|
||||
// seedWorkspaceWithAgent creates a full workspace chain with a connected
|
||||
// agent. This is the common setup needed by tests that exercise tool
|
||||
// execution against a workspace.
|
||||
@@ -2973,12 +3070,15 @@ func setOpenAIProviderBaseURL(
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -3552,12 +3652,13 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
|
||||
// Add an Anthropic provider pointing to our mock server.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-anthropic-key",
|
||||
BaseUrl: anthropicSrv.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-anthropic-key",
|
||||
BaseUrl: anthropicSrv.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3841,6 +3942,135 @@ func TestInterruptChatPersistsPartialResponse(t *testing.T) {
|
||||
"partial assistant response should contain the streamed text")
|
||||
}
|
||||
|
||||
func TestProcessChat_UserProviderKey_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
const userAPIKey = "user-test-key"
|
||||
|
||||
var authHeadersMu sync.Mutex
|
||||
authHeaders := make([]string, 0, 1)
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
authHeadersMu.Lock()
|
||||
authHeaders = append(authHeaders, req.Header.Get("Authorization"))
|
||||
authHeadersMu.Unlock()
|
||||
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("user provider key success")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("hello from the saved user key")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, provider, model := seedChatDependenciesWithProviderPolicy(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
"openai-compat",
|
||||
openAIURL,
|
||||
"",
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
_, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: user.ID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: userAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
creator := newTestServer(t, db, ps, uuid.New())
|
||||
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "user-provider-key-success",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("say hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
_ = newActiveTestServer(t, db, ps)
|
||||
|
||||
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)
|
||||
|
||||
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
||||
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
||||
require.False(t, chatResult.LastError.Valid)
|
||||
|
||||
authHeadersMu.Lock()
|
||||
recordedAuthHeaders := append([]string(nil), authHeaders...)
|
||||
authHeadersMu.Unlock()
|
||||
require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey)
|
||||
}
|
||||
|
||||
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
var llmCalls atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
llmCalls.Add(1)
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("unexpected non-streaming request")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("unexpected streaming request")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, _, model := seedChatDependenciesWithProviderPolicy(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
"openai-compat",
|
||||
openAIURL,
|
||||
"",
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
creator := newTestServer(t, db, ps, uuid.New())
|
||||
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "user-provider-key-missing",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("say hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
_ = newActiveTestServer(t, db, ps)
|
||||
|
||||
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
||||
require.Equal(t, codersdk.ChatStatusError, terminalStatus)
|
||||
|
||||
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
||||
require.Equal(t, database.ChatStatusError, chatResult.Status)
|
||||
require.True(t, chatResult.LastError.Valid, "LastError should be set")
|
||||
require.NotEmpty(t, chatResult.LastError.String)
|
||||
require.NotContains(t, chatResult.LastError.String, "panicked")
|
||||
require.NotEqual(t, database.ChatStatusRunning, chatResult.Status)
|
||||
require.Zero(t, llmCalls.Load(), "missing user key should fail before any LLM request")
|
||||
}
|
||||
|
||||
func TestProcessChatPanicRecovery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1459,11 +1459,12 @@ func TestNulEscapeRoundTrip(t *testing.T) {
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1943,11 +1944,12 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "anthropic",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "anthropic",
|
||||
DisplayName: "anthropic",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -81,11 +81,28 @@ type ProviderAPIKeys struct {
|
||||
BaseURLByProvider map[string]string
|
||||
}
|
||||
|
||||
// UserProviderKey is a user-supplied API key for a specific provider.
|
||||
type UserProviderKey struct {
|
||||
ChatProviderID uuid.UUID
|
||||
APIKey string
|
||||
}
|
||||
|
||||
// ProviderAvailability describes whether a provider has a usable
|
||||
// API key and, if not, why.
|
||||
type ProviderAvailability struct {
|
||||
Available bool
|
||||
UnavailableReason codersdk.ChatModelProviderUnavailableReason
|
||||
}
|
||||
|
||||
// ConfiguredProvider is an enabled provider loaded from database config.
|
||||
type ConfiguredProvider struct {
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
ProviderID uuid.UUID
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
CentralAPIKeyEnabled bool
|
||||
AllowUserAPIKey bool
|
||||
AllowCentralAPIKeyFallback bool
|
||||
}
|
||||
|
||||
// ConfiguredModel is an enabled model loaded from database config.
|
||||
@@ -189,21 +206,146 @@ func MergeProviderAPIKeys(fallback ProviderAPIKeys, providers []ConfiguredProvid
|
||||
return merged
|
||||
}
|
||||
|
||||
type ModelCatalog struct {
|
||||
keys ProviderAPIKeys
|
||||
// ResolveUserProviderKeys computes effective API keys and per-provider
|
||||
// availability for a given user. It considers the provider's credential
|
||||
// policy flags alongside central (DB/deployment) keys and the user's
|
||||
// personal keys.
|
||||
func ResolveUserProviderKeys(
|
||||
fallback ProviderAPIKeys,
|
||||
providers []ConfiguredProvider,
|
||||
userKeys []UserProviderKey,
|
||||
) (ProviderAPIKeys, map[string]ProviderAvailability) {
|
||||
merged := ProviderAPIKeys{
|
||||
OpenAI: strings.TrimSpace(fallback.OpenAI),
|
||||
Anthropic: strings.TrimSpace(fallback.Anthropic),
|
||||
ByProvider: map[string]string{},
|
||||
BaseURLByProvider: map[string]string{},
|
||||
}
|
||||
for provider, apiKey := range fallback.ByProvider {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if key := strings.TrimSpace(apiKey); key != "" {
|
||||
merged.ByProvider[normalizedProvider] = key
|
||||
}
|
||||
}
|
||||
for provider, baseURL := range fallback.BaseURLByProvider {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
if url := strings.TrimSpace(baseURL); url != "" {
|
||||
merged.BaseURLByProvider[normalizedProvider] = url
|
||||
}
|
||||
}
|
||||
if merged.OpenAI != "" {
|
||||
merged.ByProvider[fantasyopenai.Name] = merged.OpenAI
|
||||
}
|
||||
if merged.Anthropic != "" {
|
||||
merged.ByProvider[fantasyanthropic.Name] = merged.Anthropic
|
||||
}
|
||||
|
||||
userKeyByProviderID := make(map[uuid.UUID]string, len(userKeys))
|
||||
for _, userKey := range userKeys {
|
||||
if userKey.ChatProviderID == uuid.Nil {
|
||||
continue
|
||||
}
|
||||
if key := strings.TrimSpace(userKey.APIKey); key != "" {
|
||||
userKeyByProviderID[userKey.ChatProviderID] = key
|
||||
}
|
||||
}
|
||||
|
||||
availabilityByProvider := make(map[string]ProviderAvailability, len(providers))
|
||||
for _, provider := range providers {
|
||||
normalizedProvider := NormalizeProvider(provider.Provider)
|
||||
if normalizedProvider == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if url := strings.TrimSpace(provider.BaseURL); url != "" {
|
||||
merged.BaseURLByProvider[normalizedProvider] = url
|
||||
}
|
||||
|
||||
var userKey string
|
||||
if provider.ProviderID != uuid.Nil {
|
||||
userKey = userKeyByProviderID[provider.ProviderID]
|
||||
}
|
||||
|
||||
var centralKey string
|
||||
if provider.CentralAPIKeyEnabled {
|
||||
if key := strings.TrimSpace(provider.APIKey); key != "" {
|
||||
centralKey = key
|
||||
} else {
|
||||
centralKey = fallback.APIKey(normalizedProvider)
|
||||
}
|
||||
}
|
||||
|
||||
resolved := ProviderAvailability{}
|
||||
chosenKey := ""
|
||||
switch {
|
||||
case provider.AllowUserAPIKey && userKey != "":
|
||||
chosenKey = userKey
|
||||
resolved.Available = true
|
||||
case centralKey != "":
|
||||
if !provider.AllowUserAPIKey || provider.AllowCentralAPIKeyFallback {
|
||||
chosenKey = centralKey
|
||||
resolved.Available = true
|
||||
} else {
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
|
||||
}
|
||||
case provider.AllowUserAPIKey && provider.AllowCentralAPIKeyFallback && provider.CentralAPIKeyEnabled:
|
||||
// When users can add their own key, a missing central fallback key is
|
||||
// still something the user can remedy.
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
|
||||
case provider.AllowUserAPIKey:
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired
|
||||
default:
|
||||
resolved.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
}
|
||||
|
||||
setResolvedProviderAPIKey(&merged, normalizedProvider, chosenKey)
|
||||
availabilityByProvider[normalizedProvider] = resolved
|
||||
}
|
||||
|
||||
return merged, availabilityByProvider
|
||||
}
|
||||
|
||||
func NewModelCatalog(keys ProviderAPIKeys) *ModelCatalog {
|
||||
return &ModelCatalog{
|
||||
keys: keys,
|
||||
func setResolvedProviderAPIKey(keys *ProviderAPIKeys, provider string, apiKey string) {
|
||||
normalizedProvider := NormalizeProvider(provider)
|
||||
if normalizedProvider == "" {
|
||||
return
|
||||
}
|
||||
if keys.ByProvider == nil {
|
||||
keys.ByProvider = map[string]string{}
|
||||
}
|
||||
|
||||
delete(keys.ByProvider, normalizedProvider)
|
||||
trimmedKey := strings.TrimSpace(apiKey)
|
||||
switch normalizedProvider {
|
||||
case fantasyopenai.Name:
|
||||
keys.OpenAI = trimmedKey
|
||||
case fantasyanthropic.Name:
|
||||
keys.Anthropic = trimmedKey
|
||||
}
|
||||
if trimmedKey != "" {
|
||||
keys.ByProvider[normalizedProvider] = trimmedKey
|
||||
}
|
||||
}
|
||||
|
||||
type ModelCatalog struct{}
|
||||
|
||||
func NewModelCatalog() *ModelCatalog {
|
||||
return &ModelCatalog{}
|
||||
}
|
||||
|
||||
// ListConfiguredModels returns a model catalog from enabled DB-backed model
|
||||
// configs. The second return value reports whether DB-backed models were used.
|
||||
func (c *ModelCatalog) ListConfiguredModels(
|
||||
func (*ModelCatalog) ListConfiguredModels(
|
||||
configuredProviders []ConfiguredProvider,
|
||||
configuredModels []ConfiguredModel,
|
||||
availabilityByProvider map[string]ProviderAvailability,
|
||||
enabledProviders map[string]struct{},
|
||||
) (codersdk.ChatModelsResponse, bool) {
|
||||
if len(configuredModels) == 0 {
|
||||
return codersdk.ChatModelsResponse{}, false
|
||||
@@ -247,11 +389,14 @@ func (c *ModelCatalog) ListConfiguredModels(
|
||||
return codersdk.ChatModelsResponse{}, false
|
||||
}
|
||||
|
||||
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
|
||||
response := codersdk.ChatModelsResponse{
|
||||
Providers: make([]codersdk.ChatModelProvider, 0, len(providers)),
|
||||
}
|
||||
for _, provider := range providers {
|
||||
if _, ok := enabledProviders[provider]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
models := modelsByProvider[provider]
|
||||
sortChatModels(models)
|
||||
|
||||
@@ -259,11 +404,14 @@ func (c *ModelCatalog) ListConfiguredModels(
|
||||
Provider: provider,
|
||||
Models: models,
|
||||
}
|
||||
if keys.APIKey(provider) == "" {
|
||||
if avail, ok := availabilityByProvider[provider]; ok {
|
||||
result.Available = avail.Available
|
||||
if !avail.Available {
|
||||
result.UnavailableReason = avail.UnavailableReason
|
||||
}
|
||||
} else {
|
||||
result.Available = false
|
||||
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
} else {
|
||||
result.Available = true
|
||||
}
|
||||
|
||||
response.Providers = append(response.Providers, result)
|
||||
@@ -273,25 +421,32 @@ func (c *ModelCatalog) ListConfiguredModels(
|
||||
}
|
||||
|
||||
// ListConfiguredProviderAvailability returns provider availability derived from
|
||||
// deployment/env keys merged with enabled DB provider keys.
|
||||
func (c *ModelCatalog) ListConfiguredProviderAvailability(
|
||||
configuredProviders []ConfiguredProvider,
|
||||
// the policy-aware availability map for enabled providers.
|
||||
func (*ModelCatalog) ListConfiguredProviderAvailability(
|
||||
availabilityByProvider map[string]ProviderAvailability,
|
||||
enabledProviders map[string]struct{},
|
||||
) codersdk.ChatModelsResponse {
|
||||
keys := MergeProviderAPIKeys(c.keys, configuredProviders)
|
||||
response := codersdk.ChatModelsResponse{
|
||||
Providers: make([]codersdk.ChatModelProvider, 0, len(supportedProviderNames)),
|
||||
}
|
||||
|
||||
for _, provider := range supportedProviderNames {
|
||||
if _, ok := enabledProviders[provider]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
result := codersdk.ChatModelProvider{
|
||||
Provider: provider,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}
|
||||
if keys.APIKey(provider) == "" {
|
||||
if avail, ok := availabilityByProvider[provider]; ok {
|
||||
result.Available = avail.Available
|
||||
if !avail.Available {
|
||||
result.UnavailableReason = avail.UnavailableReason
|
||||
}
|
||||
} else {
|
||||
result.Available = false
|
||||
result.UnavailableReason = codersdk.ChatModelProviderUnavailableMissingAPIKey
|
||||
} else {
|
||||
result.Available = true
|
||||
}
|
||||
|
||||
response.Providers = append(response.Providers, result)
|
||||
@@ -300,6 +455,27 @@ func (c *ModelCatalog) ListConfiguredProviderAvailability(
|
||||
return response
|
||||
}
|
||||
|
||||
// PruneDisabledProviderKeys removes entries from keys that do not
|
||||
// belong to an enabled provider. It clears ByProvider and
|
||||
// BaseURLByProvider entries for disabled providers and zeroes the
|
||||
// legacy OpenAI and Anthropic fields when those providers are not
|
||||
// enabled.
|
||||
func PruneDisabledProviderKeys(keys *ProviderAPIKeys, enabledProviders map[string]struct{}) {
|
||||
for provider := range keys.ByProvider {
|
||||
if _, ok := enabledProviders[provider]; ok {
|
||||
continue
|
||||
}
|
||||
delete(keys.ByProvider, provider)
|
||||
delete(keys.BaseURLByProvider, provider)
|
||||
}
|
||||
if _, ok := enabledProviders[NormalizeProvider("openai")]; !ok {
|
||||
keys.OpenAI = ""
|
||||
}
|
||||
if _, ok := enabledProviders[NormalizeProvider("anthropic")]; !ok {
|
||||
keys.Anthropic = ""
|
||||
}
|
||||
}
|
||||
|
||||
func newChatModel(provider, modelID, displayName string) codersdk.ChatModel {
|
||||
name := strings.TrimSpace(displayName)
|
||||
if name == "" {
|
||||
|
||||
@@ -21,6 +21,166 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestResolveUserProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configuredProvider := func(id uuid.UUID, provider string, centralEnabled bool, centralKey string, allowUser bool, allowCentralFallback bool) chatprovider.ConfiguredProvider {
|
||||
return chatprovider.ConfiguredProvider{
|
||||
ProviderID: id,
|
||||
Provider: provider,
|
||||
APIKey: centralKey,
|
||||
CentralAPIKeyEnabled: centralEnabled,
|
||||
AllowUserAPIKey: allowUser,
|
||||
AllowCentralAPIKeyFallback: allowCentralFallback,
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKey := func(id uuid.UUID, apiKey string) chatprovider.UserProviderKey {
|
||||
return chatprovider.UserProviderKey{
|
||||
ChatProviderID: id,
|
||||
APIKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
openAIProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
anthropicProviderID := uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fallback chatprovider.ProviderAPIKeys
|
||||
providers []chatprovider.ConfiguredProvider
|
||||
userKeys []chatprovider.UserProviderKey
|
||||
wantAvailability map[string]chatprovider.ProviderAvailability
|
||||
wantKeys map[string]string
|
||||
}{
|
||||
{
|
||||
name: "CentralOnlyKeyPresent",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-central",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CentralOnlyKeyMissing",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", false, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UserOnlyUserHasKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)},
|
||||
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-user",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UserOnlyUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, false, "sk-central", true, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOffUserHasKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)},
|
||||
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-user",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOffUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, false)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOnUserHasKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)},
|
||||
userKeys: []chatprovider.UserProviderKey{userProviderKey(openAIProviderID, "sk-user")},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-user",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOnUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", true, true)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-central",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BothEnabledFallbackOnCentralKeyEmptyUserHasNoKey",
|
||||
providers: []chatprovider.ConfiguredProvider{configuredProvider(openAIProviderID, fantasyopenai.Name, true, "", true, true)},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MultipleProvidersDifferentPolicies",
|
||||
providers: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(openAIProviderID, fantasyopenai.Name, true, "sk-central", false, false),
|
||||
configuredProvider(anthropicProviderID, fantasyanthropic.Name, false, "", true, false),
|
||||
},
|
||||
wantAvailability: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {Available: true},
|
||||
fantasyanthropic.Name: {Available: false, UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired},
|
||||
},
|
||||
wantKeys: map[string]string{
|
||||
fantasyopenai.Name: "sk-central",
|
||||
fantasyanthropic.Name: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keys, availability := chatprovider.ResolveUserProviderKeys(tt.fallback, tt.providers, tt.userKeys)
|
||||
|
||||
require.Len(t, availability, len(tt.wantAvailability))
|
||||
for provider, wantAvailability := range tt.wantAvailability {
|
||||
gotAvailability, ok := availability[provider]
|
||||
require.True(t, ok, "expected availability for provider %q", provider)
|
||||
require.Equal(t, wantAvailability, gotAvailability)
|
||||
require.Equal(t, tt.wantKeys[provider], keys.APIKey(provider))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReasoningEffortFromChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -91,6 +251,413 @@ func TestReasoningEffortFromChat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUserProviderKeys_UnavailableReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
provider chatprovider.ConfiguredProvider
|
||||
wantReason codersdk.ChatModelProviderUnavailableReason
|
||||
}{
|
||||
{
|
||||
name: "FallbackConfiguredWithoutCentralKeyReturnsUserAPIKeyRequired",
|
||||
provider: chatprovider.ConfiguredProvider{
|
||||
Provider: "anthropic",
|
||||
CentralAPIKeyEnabled: true,
|
||||
AllowUserAPIKey: true,
|
||||
AllowCentralAPIKeyFallback: true,
|
||||
},
|
||||
wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
{
|
||||
name: "UserKeyRequiredWithoutFallback",
|
||||
provider: chatprovider.ConfiguredProvider{
|
||||
Provider: "anthropic",
|
||||
CentralAPIKeyEnabled: true,
|
||||
AllowUserAPIKey: true,
|
||||
},
|
||||
wantReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keys, availability := chatprovider.ResolveUserProviderKeys(
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
[]chatprovider.ConfiguredProvider{tt.provider},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Empty(t, keys.APIKey(tt.provider.Provider))
|
||||
resolved, ok := availability[tt.provider.Provider]
|
||||
require.True(t, ok)
|
||||
require.False(t, resolved.Available)
|
||||
require.Equal(t, tt.wantReason, resolved.UnavailableReason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConfiguredModels_PolicyAwareAvailability(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configuredProvider := func(provider string, apiKey string) chatprovider.ConfiguredProvider {
|
||||
return chatprovider.ConfiguredProvider{
|
||||
ProviderID: uuid.New(),
|
||||
Provider: provider,
|
||||
APIKey: apiKey,
|
||||
}
|
||||
}
|
||||
enabledProviders := func(providers ...string) map[string]struct{} {
|
||||
result := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
catalog := chatprovider.NewModelCatalog()
|
||||
tests := []struct {
|
||||
name string
|
||||
configuredProviders []chatprovider.ConfiguredProvider
|
||||
configuredModels []chatprovider.ConfiguredModel
|
||||
availabilityByProvider map[string]chatprovider.ProviderAvailability
|
||||
enabledProviders map[string]struct{}
|
||||
want codersdk.ChatModelsResponse
|
||||
}{
|
||||
{
|
||||
name: "PolicyUnavailableOverridesConfiguredKey",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyopenai.Name, "sk-central"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4",
|
||||
}},
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyopenai.Name: {
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyopenai.Name + ":gpt-4",
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "gpt-4",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "PolicyAvailableMarksProviderAvailable",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyanthropic.Name, "sk-central"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{{
|
||||
Provider: fantasyanthropic.Name,
|
||||
Model: "claude-3-5-sonnet",
|
||||
}},
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyanthropic.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyanthropic.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyanthropic.Name + ":claude-3-5-sonnet",
|
||||
Provider: fantasyanthropic.Name,
|
||||
Model: "claude-3-5-sonnet",
|
||||
DisplayName: "claude-3-5-sonnet",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "DisabledProviderOmitted",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyanthropic.Name, "sk-anthropic"),
|
||||
configuredProvider(fantasyopenai.Name, "sk-openai"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{
|
||||
{Provider: fantasyanthropic.Name, Model: "claude-3-5-sonnet"},
|
||||
{Provider: fantasyopenai.Name, Model: "gpt-4"},
|
||||
},
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {Available: true},
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyopenai.Name + ":gpt-4",
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "gpt-4",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "MissingAvailabilityDefaultsToMissingAPIKey",
|
||||
configuredProviders: []chatprovider.ConfiguredProvider{
|
||||
configuredProvider(fantasyopenai.Name, "sk-central"),
|
||||
},
|
||||
configuredModels: []chatprovider.ConfiguredModel{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4o",
|
||||
}},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey,
|
||||
Models: []codersdk.ChatModel{{
|
||||
ID: fantasyopenai.Name + ":gpt-4o",
|
||||
Provider: fantasyopenai.Name,
|
||||
Model: "gpt-4o",
|
||||
DisplayName: "gpt-4o",
|
||||
}},
|
||||
}}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := catalog.ListConfiguredModels(
|
||||
tt.configuredProviders,
|
||||
tt.configuredModels,
|
||||
tt.availabilityByProvider,
|
||||
tt.enabledProviders,
|
||||
)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConfiguredProviderAvailability_PolicyAwareFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
enabledProviders := func(providers ...string) map[string]struct{} {
|
||||
result := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
catalog := chatprovider.NewModelCatalog()
|
||||
tests := []struct {
|
||||
name string
|
||||
availabilityByProvider map[string]chatprovider.ProviderAvailability
|
||||
enabledProviders map[string]struct{}
|
||||
want codersdk.ChatModelsResponse
|
||||
}{
|
||||
{
|
||||
name: "EnabledProvidersUsePolicyAvailability",
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
},
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyanthropic.Name, fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{
|
||||
{
|
||||
Provider: fantasyanthropic.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired,
|
||||
Models: []codersdk.ChatModel{},
|
||||
},
|
||||
{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "DisabledSupportedProviderOmitted",
|
||||
availabilityByProvider: map[string]chatprovider.ProviderAvailability{
|
||||
fantasyanthropic.Name: {Available: true},
|
||||
fantasyopenai.Name: {Available: true},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: true,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}}},
|
||||
},
|
||||
{
|
||||
name: "MissingAvailabilityDefaultsToMissingAPIKey",
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: codersdk.ChatModelsResponse{Providers: []codersdk.ChatModelProvider{{
|
||||
Provider: fantasyopenai.Name,
|
||||
Available: false,
|
||||
UnavailableReason: codersdk.ChatModelProviderUnavailableMissingAPIKey,
|
||||
Models: []codersdk.ChatModel{},
|
||||
}}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := catalog.ListConfiguredProviderAvailability(
|
||||
tt.availabilityByProvider,
|
||||
tt.enabledProviders,
|
||||
)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneDisabledProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
enabledProviders := func(providers ...string) map[string]struct{} {
|
||||
result := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
result[chatprovider.NormalizeProvider(provider)] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keys chatprovider.ProviderAPIKeys
|
||||
enabledProviders map[string]struct{}
|
||||
want chatprovider.ProviderAPIKeys
|
||||
}{
|
||||
{
|
||||
name: "DisabledProviderEntriesRemoved",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OpenAIDisabledClearsLegacyField",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyanthropic.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AnthropicDisabledClearsLegacyField",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AllEnabledLeavesKeysUnchanged",
|
||||
keys: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
enabledProviders: enabledProviders(fantasyopenai.Name, fantasyanthropic.Name),
|
||||
want: chatprovider.ProviderAPIKeys{
|
||||
OpenAI: "sk-openai",
|
||||
Anthropic: "sk-anthropic",
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenai.Name: "sk-openai",
|
||||
fantasyanthropic.Name: "sk-anthropic",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenai.Name: "https://openai.example.com",
|
||||
fantasyanthropic.Name: "https://anthropic.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keys := tt.keys
|
||||
chatprovider.PruneDisabledProviderKeys(&keys, tt.enabledProviders)
|
||||
require.Equal(t, tt.want, keys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCoderHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -440,13 +440,14 @@ func seedModelConfig(
|
||||
t.Helper()
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -122,13 +122,14 @@ func seedInternalChatDeps(
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -959,9 +959,10 @@ func TestWorker(t *testing.T) {
|
||||
|
||||
// 3. Set up FK chain: chat_providers -> chat_model_configs -> chats.
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
+86
-18
@@ -405,6 +405,8 @@ type ChatModelProviderUnavailableReason string
|
||||
const (
|
||||
ChatModelProviderUnavailableMissingAPIKey ChatModelProviderUnavailableReason = "missing_api_key"
|
||||
ChatModelProviderUnavailableFetchFailed ChatModelProviderUnavailableReason = "fetch_failed"
|
||||
// #nosec G101
|
||||
ChatModelProviderUnavailableReasonUserAPIKeyRequired ChatModelProviderUnavailableReason = "user_api_key_required"
|
||||
)
|
||||
|
||||
// ChatModel represents a model in the chat model catalog.
|
||||
@@ -532,32 +534,57 @@ const (
|
||||
|
||||
// ChatProviderConfig is an admin-managed provider configuration.
|
||||
type ChatProviderConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HasAPIKey bool `json:"has_api_key"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Source ChatProviderConfigSource `json:"source"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
HasAPIKey bool `json:"has_api_key"`
|
||||
CentralAPIKeyEnabled bool `json:"central_api_key_enabled"`
|
||||
AllowUserAPIKey bool `json:"allow_user_api_key"`
|
||||
AllowCentralAPIKeyFallback bool `json:"allow_central_api_key_fallback"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Source ChatProviderConfigSource `json:"source"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// CreateChatProviderConfigRequest creates a chat provider config.
|
||||
type CreateChatProviderConfigRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
CentralAPIKeyEnabled *bool `json:"central_api_key_enabled,omitempty"`
|
||||
AllowUserAPIKey *bool `json:"allow_user_api_key,omitempty"`
|
||||
AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatProviderConfigRequest updates a chat provider config.
|
||||
type UpdateChatProviderConfigRequest struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey *string `json:"api_key,omitempty"`
|
||||
BaseURL *string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
APIKey *string `json:"api_key,omitempty"`
|
||||
BaseURL *string `json:"base_url,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
CentralAPIKeyEnabled *bool `json:"central_api_key_enabled,omitempty"`
|
||||
AllowUserAPIKey *bool `json:"allow_user_api_key,omitempty"`
|
||||
AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"`
|
||||
}
|
||||
|
||||
// UserChatProviderConfig is a summary of a provider that allows
|
||||
// user-supplied keys, as seen from the current user's perspective.
|
||||
type UserChatProviderConfig struct {
|
||||
ProviderID uuid.UUID `json:"provider_id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
DisplayName string `json:"display_name"`
|
||||
HasUserAPIKey bool `json:"has_user_api_key"`
|
||||
HasCentralAPIKeyFallback bool `json:"has_central_api_key_fallback"`
|
||||
}
|
||||
|
||||
// CreateUserChatProviderKeyRequest creates or replaces a user's API key
|
||||
// for a provider.
|
||||
type CreateUserChatProviderKeyRequest struct {
|
||||
APIKey string `json:"api_key"`
|
||||
}
|
||||
|
||||
// ChatModelConfig is an admin-managed model configuration.
|
||||
@@ -1332,6 +1359,47 @@ func (c *ExperimentalClient) DeleteChatProvider(ctx context.Context, providerID
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListUserChatProviderConfigs returns user-scoped chat provider configs.
|
||||
func (c *ExperimentalClient) ListUserChatProviderConfigs(ctx context.Context) ([]UserChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-provider-configs", nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list user chat provider configs: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var configs []UserChatProviderConfig
|
||||
return configs, json.NewDecoder(res.Body).Decode(&configs)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey creates or replaces a user API key for a provider.
|
||||
func (c *ExperimentalClient) UpsertUserChatProviderKey(ctx context.Context, providerID uuid.UUID, req CreateUserChatProviderKeyRequest) (UserChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPut, fmt.Sprintf("/api/experimental/chats/user-provider-configs/%s", providerID), req)
|
||||
if err != nil {
|
||||
return UserChatProviderConfig{}, xerrors.Errorf("upsert user chat provider key: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserChatProviderConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
var config UserChatProviderConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
// DeleteUserChatProviderKey deletes a user API key for a provider.
|
||||
func (c *ExperimentalClient) DeleteUserChatProviderKey(ctx context.Context, providerID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/user-provider-configs/%s", providerID), nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete user chat provider key: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListChatModelConfigs returns admin-managed chat model configs.
|
||||
func (c *ExperimentalClient) ListChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/model-configs", nil)
|
||||
|
||||
@@ -104,13 +104,14 @@ func seedChatDependencies(
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: safetyNet.URL,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: safetyNet.URL,
|
||||
CentralApiKeyEnabled: true,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
@@ -186,12 +187,15 @@ func setOpenAIProviderBaseURL(
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: baseURL,
|
||||
CentralApiKeyEnabled: true,
|
||||
AllowUserApiKey: false,
|
||||
AllowCentralApiKeyFallback: false,
|
||||
ApiKeyKeyID: provider.ApiKeyKeyID,
|
||||
Enabled: provider.Enabled,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -73,12 +73,35 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err)
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
|
||||
}
|
||||
for _, userProviderKey := range userProviderKeys {
|
||||
if strings.TrimSpace(userProviderKey.APIKey) == "" {
|
||||
continue
|
||||
}
|
||||
if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
|
||||
UserID: userProviderKey.UserID,
|
||||
ChatProviderID: userProviderKey.ChatProviderID,
|
||||
APIKey: userProviderKey.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
return nil
|
||||
}, &database.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update user links: %w", err)
|
||||
return xerrors.Errorf("update user tokens and chat provider keys: %w", err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
@@ -97,12 +120,15 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
@@ -189,12 +215,32 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err)
|
||||
}
|
||||
}
|
||||
|
||||
userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err)
|
||||
}
|
||||
for _, userProviderKey := range userProviderKeys {
|
||||
if !userProviderKey.ApiKeyKeyID.Valid {
|
||||
log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{
|
||||
UserID: userProviderKey.UserID,
|
||||
ChatProviderID: userProviderKey.ChatProviderID,
|
||||
APIKey: userProviderKey.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1))
|
||||
}
|
||||
return nil
|
||||
}, &database.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update user links: %w", err)
|
||||
return xerrors.Errorf("update user tokens and chat provider keys: %w", err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
@@ -209,12 +255,15 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
||||
continue
|
||||
}
|
||||
if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
Enabled: provider.Enabled,
|
||||
ID: provider.ID,
|
||||
DisplayName: provider.DisplayName,
|
||||
APIKey: provider.APIKey,
|
||||
BaseUrl: provider.BaseUrl,
|
||||
ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id
|
||||
Enabled: provider.Enabled,
|
||||
CentralApiKeyEnabled: provider.CentralApiKeyEnabled,
|
||||
AllowUserApiKey: provider.AllowUserApiKey,
|
||||
AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback,
|
||||
ID: provider.ID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err)
|
||||
}
|
||||
@@ -241,6 +290,8 @@ DELETE FROM user_links
|
||||
DELETE FROM external_auth_links
|
||||
WHERE oauth_access_token_key_id IS NOT NULL
|
||||
OR oauth_refresh_token_key_id IS NOT NULL;
|
||||
DELETE FROM user_chat_provider_keys
|
||||
WHERE api_key_key_id IS NOT NULL;
|
||||
UPDATE chat_providers
|
||||
SET api_key = '',
|
||||
api_key_key_id = NULL
|
||||
|
||||
@@ -471,6 +471,57 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
|
||||
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range keys {
|
||||
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
|
||||
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
if strings.TrimSpace(params.APIKey) == "" {
|
||||
params.ApiKeyKeyID = sql.NullString{}
|
||||
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
|
||||
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
||||
return database.UserChatProviderKey{}, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// decryptMCPServerConfig decrypts all encrypted fields on a
|
||||
// single MCPServerConfig in place.
|
||||
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
|
||||
|
||||
@@ -1177,3 +1177,113 @@ func TestMCPServerUserTokens(t *testing.T) {
|
||||
requireEncryptedEquals(t, ciphers[0], rawTok.RefreshToken, refreshToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserChatProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
//nolint:gosec // test credentials
|
||||
initialAPIKey = "sk-initial-api-key-value"
|
||||
//nolint:gosec // test credentials
|
||||
updatedAPIKey = "sk-updated-api-key-value"
|
||||
)
|
||||
|
||||
insertProviderAndKey := func(
|
||||
t *testing.T,
|
||||
crypt *dbCrypt,
|
||||
ciphers []Cipher,
|
||||
) (database.ChatProvider, database.UserChatProviderKey) {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
provider, err := crypt.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "",
|
||||
Enabled: true,
|
||||
AllowUserApiKey: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: user.ID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: initialAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialAPIKey, key.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String)
|
||||
return provider, key
|
||||
}
|
||||
|
||||
getUserChatProviderKey := func(t *testing.T, store interface {
|
||||
GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error)
|
||||
}, userID uuid.UUID, providerID uuid.UUID,
|
||||
) database.UserChatProviderKey {
|
||||
t.Helper()
|
||||
keys, err := store.GetUserChatProviderKeys(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, providerID, keys[0].ChatProviderID)
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
|
||||
require.Equal(t, key.ID, got.ID)
|
||||
require.Equal(t, initialAPIKey, got.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
|
||||
|
||||
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
|
||||
require.NotEqual(t, initialAPIKey, rawKey.APIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey)
|
||||
})
|
||||
|
||||
t.Run("GetUserChatProviderKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, crypt, ciphers := setup(t)
|
||||
_, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, key.ID, keys[0].ID)
|
||||
require.Equal(t, initialAPIKey, keys[0].APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String)
|
||||
})
|
||||
|
||||
t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
provider, key := insertProviderAndKey(t, crypt, ciphers)
|
||||
|
||||
updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{
|
||||
UserID: key.UserID,
|
||||
ChatProviderID: provider.ID,
|
||||
APIKey: updatedAPIKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key.ID, updated.ID)
|
||||
require.Equal(t, key.CreatedAt, updated.CreatedAt)
|
||||
require.False(t, updated.UpdatedAt.Before(key.UpdatedAt))
|
||||
require.Equal(t, updatedAPIKey, updated.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String)
|
||||
|
||||
got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID)
|
||||
require.Equal(t, updatedAPIKey, got.APIKey)
|
||||
require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String)
|
||||
|
||||
keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, updatedAPIKey, keys[0].APIKey)
|
||||
|
||||
rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID)
|
||||
require.NotEqual(t, updatedAPIKey, rawKey.APIKey)
|
||||
requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey)
|
||||
})
|
||||
}
|
||||
|
||||
Generated
+34
-2
@@ -1776,10 +1776,11 @@ export interface ChatModelProviderOptions {
|
||||
// From codersdk/chats.go
|
||||
export type ChatModelProviderUnavailableReason =
|
||||
| "fetch_failed"
|
||||
| "missing_api_key";
|
||||
| "missing_api_key"
|
||||
| "user_api_key_required";
|
||||
|
||||
export const ChatModelProviderUnavailableReasons: ChatModelProviderUnavailableReason[] =
|
||||
["fetch_failed", "missing_api_key"];
|
||||
["fetch_failed", "missing_api_key", "user_api_key_required"];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
@@ -1836,6 +1837,9 @@ export interface ChatProviderConfig {
|
||||
readonly display_name: string;
|
||||
readonly enabled: boolean;
|
||||
readonly has_api_key: boolean;
|
||||
readonly central_api_key_enabled: boolean;
|
||||
readonly allow_user_api_key: boolean;
|
||||
readonly allow_central_api_key_fallback: boolean;
|
||||
readonly base_url?: string;
|
||||
readonly source: ChatProviderConfigSource;
|
||||
readonly created_at?: string;
|
||||
@@ -2362,6 +2366,9 @@ export interface CreateChatProviderConfigRequest {
|
||||
readonly api_key?: string;
|
||||
readonly base_url?: string;
|
||||
readonly enabled?: boolean;
|
||||
readonly central_api_key_enabled?: boolean;
|
||||
readonly allow_user_api_key?: boolean;
|
||||
readonly allow_central_api_key_fallback?: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -2643,6 +2650,15 @@ export interface CreateTokenRequest {
|
||||
readonly allow_list?: readonly APIAllowListTarget[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* CreateUserChatProviderKeyRequest creates or replaces a user's API key
|
||||
* for a provider.
|
||||
*/
|
||||
export interface CreateUserChatProviderKeyRequest {
|
||||
readonly api_key: string;
|
||||
}
|
||||
|
||||
// From codersdk/users.go
|
||||
export interface CreateUserRequestWithOrgs {
|
||||
readonly email: string;
|
||||
@@ -7150,6 +7166,9 @@ export interface UpdateChatProviderConfigRequest {
|
||||
readonly api_key?: string;
|
||||
readonly base_url?: string;
|
||||
readonly enabled?: boolean;
|
||||
readonly central_api_key_enabled?: boolean;
|
||||
readonly allow_user_api_key?: boolean;
|
||||
readonly allow_central_api_key_fallback?: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -7694,6 +7713,19 @@ export interface UserChatCustomPrompt {
|
||||
readonly custom_prompt: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UserChatProviderConfig is a summary of a provider that allows
|
||||
* user-supplied keys, as seen from the current user's perspective.
|
||||
*/
|
||||
export interface UserChatProviderConfig {
|
||||
readonly provider_id: string;
|
||||
readonly provider: string;
|
||||
readonly display_name: string;
|
||||
readonly has_user_api_key: boolean;
|
||||
readonly has_central_api_key_fallback: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/insights.go
|
||||
/**
|
||||
* UserLatency shows the connection latency for a user.
|
||||
|
||||
+4
@@ -20,6 +20,10 @@ const createProviderConfig = (
|
||||
display_name: overrides.display_name ?? "",
|
||||
enabled: overrides.enabled ?? true,
|
||||
has_api_key: overrides.has_api_key ?? false,
|
||||
central_api_key_enabled: overrides.central_api_key_enabled ?? true,
|
||||
allow_user_api_key: overrides.allow_user_api_key ?? false,
|
||||
allow_central_api_key_fallback:
|
||||
overrides.allow_central_api_key_fallback ?? false,
|
||||
base_url: overrides.base_url ?? "",
|
||||
source: overrides.source ?? "database",
|
||||
created_at: overrides.created_at ?? now,
|
||||
|
||||
@@ -14,6 +14,9 @@ const providerState: ProviderState = {
|
||||
display_name: "OpenAI",
|
||||
enabled: true,
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
base_url: undefined,
|
||||
source: "database",
|
||||
created_at: "2025-01-01T00:00:00Z",
|
||||
|
||||
Reference in New Issue
Block a user