mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add admin-configurable chat title generation model (#24838)
Adds an admin-configurable deployment-wide setting that controls which
model is used for chat title generation. Admins can pick any enabled
chat model config from the Agents settings page, or leave the setting
unset to keep the existing fast-models-then-chat-model fallback
algorithm.
When a model is selected, both automatic and manual title generation use
only that model, with no silent fallback. When the configured model is
disabled, missing credentials, or otherwise unusable, automatic title
generation skips entirely (best-effort) and manual title regeneration
returns a clear error, so admins notice the misconfiguration instead of
silently routing title traffic through another provider.
## Surface
- New deployment-wide setting stored as a `site_configs` row
(`agents_chat_title_generation_model_override`).
- New experimental endpoint `GET/PUT
/api/experimental/chats/config/model-override/{context}`.
- Frontend: title generation now appears as a third dropdown on the
Agents admin settings page alongside the existing general and explore
context overrides.
## DRY refactors folded in
Title generation is integrated as a third value of the existing
`ChatModelOverrideContext` type alongside `general` and `explore`,
sharing the parameterized HTTP route, SDK methods, generated types, and
frontend API plumbing rather than introducing a parallel surface. The
`Agent` prefix was dropped from the type and route since title
generation is not a delegated agent.
The chatd model-override resolver is also shared.
`resolveConfiguredModelOverride` now takes a `failureMode` parameter:
- Subagent overrides use soft failure: misconfigured overrides are
logged and the parent model is used.
- Title generation uses hard failure: misconfigured overrides return an
explicit error so manual title regeneration surfaces the
misconfiguration and automatic title generation skips instead of
silently falling back.
> Mux is acting on Mike's behalf.
This commit is contained in:
+2
-2
@@ -1192,8 +1192,8 @@ func New(options *Options) *API {
|
|||||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||||
r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions)
|
r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions)
|
||||||
r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions)
|
r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions)
|
||||||
r.Get("/agent-model-override/{context}", api.getChatAgentModelOverride)
|
r.Get("/model-override/{context}", api.getChatModelOverride)
|
||||||
r.Put("/agent-model-override/{context}", api.putChatAgentModelOverride)
|
r.Put("/model-override/{context}", api.putChatModelOverride)
|
||||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||||
r.Get("/debug-logging", api.getChatDebugLogging)
|
r.Get("/debug-logging", api.getChatDebugLogging)
|
||||||
|
|||||||
@@ -2967,6 +2967,13 @@ func (q *querier) GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
|||||||
return q.db.GetChatTemplateAllowlist(ctx)
|
return q.db.GetChatTemplateAllowlist(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *querier) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||||
|
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return q.db.GetChatTitleGenerationModelOverride(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||||
return database.ChatUsageLimitConfig{}, err
|
return database.ChatUsageLimitConfig{}, err
|
||||||
@@ -7517,6 +7524,13 @@ func (q *querier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllow
|
|||||||
return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
|
return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *querier) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||||
|
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return q.db.UpsertChatTitleGenerationModelOverride(ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||||
return database.ChatUsageLimitConfig{}, err
|
return database.ChatUsageLimitConfig{}, err
|
||||||
|
|||||||
@@ -918,6 +918,10 @@ func (s *MethodTestSuite) TestChats() {
|
|||||||
dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
||||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||||
}))
|
}))
|
||||||
|
s.Run("GetChatTitleGenerationModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||||
|
dbm.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
||||||
|
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||||
|
}))
|
||||||
s.Run("GetChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
s.Run("GetChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||||
dbm.EXPECT().GetChatPlanModeInstructions(gomock.Any()).Return("", nil).AnyTimes()
|
dbm.EXPECT().GetChatPlanModeInstructions(gomock.Any()).Return("", nil).AnyTimes()
|
||||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||||
@@ -1237,6 +1241,10 @@ func (s *MethodTestSuite) TestChats() {
|
|||||||
dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
||||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||||
}))
|
}))
|
||||||
|
s.Run("UpsertChatTitleGenerationModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||||
|
dbm.EXPECT().UpsertChatTitleGenerationModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
||||||
|
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||||
|
}))
|
||||||
s.Run("UpsertChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
s.Run("UpsertChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||||
dbm.EXPECT().UpsertChatPlanModeInstructions(gomock.Any(), "").Return(nil).AnyTimes()
|
dbm.EXPECT().UpsertChatPlanModeInstructions(gomock.Any(), "").Return(nil).AnyTimes()
|
||||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||||
|
|||||||
@@ -1456,6 +1456,14 @@ func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string
|
|||||||
return r0, r1
|
return r0, r1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m queryMetricsStore) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||||
|
start := time.Now()
|
||||||
|
r0, r1 := m.s.GetChatTitleGenerationModelOverride(ctx)
|
||||||
|
m.queryLatencies.WithLabelValues("GetChatTitleGenerationModelOverride").Observe(time.Since(start).Seconds())
|
||||||
|
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatTitleGenerationModelOverride").Inc()
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
||||||
@@ -5408,6 +5416,14 @@ func (m queryMetricsStore) UpsertChatTemplateAllowlist(ctx context.Context, temp
|
|||||||
return r0
|
return r0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m queryMetricsStore) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||||
|
start := time.Now()
|
||||||
|
r0 := m.s.UpsertChatTitleGenerationModelOverride(ctx, value)
|
||||||
|
m.queryLatencies.WithLabelValues("UpsertChatTitleGenerationModelOverride").Observe(time.Since(start).Seconds())
|
||||||
|
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatTitleGenerationModelOverride").Inc()
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
||||||
|
|||||||
@@ -2687,6 +2687,21 @@ func (mr *MockStoreMockRecorder) GetChatTemplateAllowlist(ctx any) *gomock.Call
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetChatTitleGenerationModelOverride mocks base method.
|
||||||
|
func (m *MockStore) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetChatTitleGenerationModelOverride", ctx)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChatTitleGenerationModelOverride indicates an expected call of GetChatTitleGenerationModelOverride.
|
||||||
|
func (mr *MockStoreMockRecorder) GetChatTitleGenerationModelOverride(ctx any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTitleGenerationModelOverride", reflect.TypeOf((*MockStore)(nil).GetChatTitleGenerationModelOverride), ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// GetChatUsageLimitConfig mocks base method.
|
// GetChatUsageLimitConfig mocks base method.
|
||||||
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -10152,6 +10167,20 @@ func (mr *MockStoreMockRecorder) UpsertChatTemplateAllowlist(ctx, templateAllowl
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpsertChatTitleGenerationModelOverride mocks base method.
|
||||||
|
func (m *MockStore) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "UpsertChatTitleGenerationModelOverride", ctx, value)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertChatTitleGenerationModelOverride indicates an expected call of UpsertChatTitleGenerationModelOverride.
|
||||||
|
func (mr *MockStoreMockRecorder) UpsertChatTitleGenerationModelOverride(ctx, value any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTitleGenerationModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatTitleGenerationModelOverride), ctx, value)
|
||||||
|
}
|
||||||
|
|
||||||
// UpsertChatUsageLimitConfig mocks base method.
|
// UpsertChatUsageLimitConfig mocks base method.
|
||||||
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -362,6 +362,7 @@ type sqlcQuerier interface {
|
|||||||
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||||
// Returns an empty string when no allowlist has been configured (all templates allowed).
|
// Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||||
GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
||||||
|
GetChatTitleGenerationModelOverride(ctx context.Context) (string, error)
|
||||||
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
|
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
|
||||||
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
|
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
|
||||||
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
|
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
|
||||||
@@ -1206,6 +1207,7 @@ type sqlcQuerier interface {
|
|||||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||||
|
UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error
|
||||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||||
|
|||||||
@@ -20731,6 +20731,18 @@ func (q *sqlQuerier) GetChatTemplateAllowlist(ctx context.Context) (string, erro
|
|||||||
return template_allowlist, err
|
return template_allowlist, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getChatTitleGenerationModelOverride = `-- name: GetChatTitleGenerationModelOverride :one
|
||||||
|
SELECT
|
||||||
|
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_title_generation_model_override'), '') :: text AS model_config_id
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *sqlQuerier) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getChatTitleGenerationModelOverride)
|
||||||
|
var model_config_id string
|
||||||
|
err := row.Scan(&model_config_id)
|
||||||
|
return model_config_id, err
|
||||||
|
}
|
||||||
|
|
||||||
const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one
|
const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE(
|
COALESCE(
|
||||||
@@ -21085,6 +21097,16 @@ func (q *sqlQuerier) UpsertChatTemplateAllowlist(ctx context.Context, templateAl
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const upsertChatTitleGenerationModelOverride = `-- name: UpsertChatTitleGenerationModelOverride :exec
|
||||||
|
INSERT INTO site_configs (key, value) VALUES ('agents_chat_title_generation_model_override', $1)
|
||||||
|
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_title_generation_model_override'
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *sqlQuerier) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, upsertChatTitleGenerationModelOverride, value)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec
|
const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec
|
||||||
INSERT INTO site_configs (key, value)
|
INSERT INTO site_configs (key, value)
|
||||||
VALUES ('agents_workspace_ttl', $1::text)
|
VALUES ('agents_workspace_ttl', $1::text)
|
||||||
|
|||||||
@@ -183,6 +183,14 @@ SELECT
|
|||||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
|
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
|
||||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override';
|
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override';
|
||||||
|
|
||||||
|
-- name: GetChatTitleGenerationModelOverride :one
|
||||||
|
SELECT
|
||||||
|
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_title_generation_model_override'), '') :: text AS model_config_id;
|
||||||
|
|
||||||
|
-- name: UpsertChatTitleGenerationModelOverride :exec
|
||||||
|
INSERT INTO site_configs (key, value) VALUES ('agents_chat_title_generation_model_override', $1)
|
||||||
|
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_title_generation_model_override';
|
||||||
|
|
||||||
-- name: GetChatDesktopEnabled :one
|
-- name: GetChatDesktopEnabled :one
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
|
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
|
||||||
|
|||||||
+58
-41
@@ -532,62 +532,72 @@ func (api *API) getChatModelOverrideConfig(
|
|||||||
return id, false, nil
|
return id, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseChatAgentModelOverrideContext(raw string) (codersdk.ChatAgentModelOverrideContext, error) {
|
func parseChatModelOverrideContext(raw string) (codersdk.ChatModelOverrideContext, error) {
|
||||||
overrideContext := codersdk.ChatAgentModelOverrideContext(raw)
|
overrideContext := codersdk.ChatModelOverrideContext(raw)
|
||||||
if overrideContext.Valid() {
|
if overrideContext.Valid() {
|
||||||
return overrideContext, nil
|
return overrideContext, nil
|
||||||
}
|
}
|
||||||
return "", xerrors.Errorf("unknown chat agent model override context %q", raw)
|
return "", xerrors.Errorf("unknown chat model override context %q", raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatAgentModelOverrideSiteConfig struct {
|
type chatModelOverrideSiteConfig struct {
|
||||||
|
label string
|
||||||
getter func(context.Context) (string, error)
|
getter func(context.Context) (string, error)
|
||||||
upsert func(context.Context, string) error
|
upsert func(context.Context, string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) chatAgentModelOverrideSiteConfig(
|
func (api *API) chatModelOverrideSiteConfig(
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
) (chatAgentModelOverrideSiteConfig, error) {
|
) (chatModelOverrideSiteConfig, error) {
|
||||||
switch overrideContext {
|
switch overrideContext {
|
||||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
case codersdk.ChatModelOverrideContextGeneral:
|
||||||
return chatAgentModelOverrideSiteConfig{
|
return chatModelOverrideSiteConfig{
|
||||||
|
label: "general",
|
||||||
getter: api.Database.GetChatGeneralModelOverride,
|
getter: api.Database.GetChatGeneralModelOverride,
|
||||||
upsert: api.Database.UpsertChatGeneralModelOverride,
|
upsert: api.Database.UpsertChatGeneralModelOverride,
|
||||||
}, nil
|
}, nil
|
||||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
case codersdk.ChatModelOverrideContextExplore:
|
||||||
return chatAgentModelOverrideSiteConfig{
|
return chatModelOverrideSiteConfig{
|
||||||
|
label: "explore",
|
||||||
getter: api.Database.GetChatExploreModelOverride,
|
getter: api.Database.GetChatExploreModelOverride,
|
||||||
upsert: api.Database.UpsertChatExploreModelOverride,
|
upsert: api.Database.UpsertChatExploreModelOverride,
|
||||||
}, nil
|
}, nil
|
||||||
|
case codersdk.ChatModelOverrideContextTitleGeneration:
|
||||||
|
return chatModelOverrideSiteConfig{
|
||||||
|
label: "title generation",
|
||||||
|
getter: api.Database.GetChatTitleGenerationModelOverride,
|
||||||
|
upsert: api.Database.UpsertChatTitleGenerationModelOverride,
|
||||||
|
}, nil
|
||||||
default:
|
default:
|
||||||
return chatAgentModelOverrideSiteConfig{}, xerrors.Errorf(
|
return chatModelOverrideSiteConfig{}, xerrors.Errorf(
|
||||||
"unknown chat agent model override context %q",
|
"unknown chat model override context %q",
|
||||||
overrideContext,
|
overrideContext,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) getChatAgentModelOverrideConfig(
|
func (api *API) readChatModelOverrideConfig(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
) (*uuid.UUID, bool, error) {
|
) (*uuid.UUID, bool, string, error) {
|
||||||
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
|
siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, "", err
|
||||||
}
|
}
|
||||||
return api.getChatModelOverrideConfig(ctx, string(overrideContext), siteConfig.getter)
|
id, isMalformed, err := api.getChatModelOverrideConfig(ctx, siteConfig.label, siteConfig.getter)
|
||||||
|
return id, isMalformed, siteConfig.label, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) upsertChatAgentModelOverrideConfig(
|
func (api *API) upsertChatModelOverrideConfig(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
modelConfigID *uuid.UUID,
|
modelConfigID *uuid.UUID,
|
||||||
) error {
|
) (string, error) {
|
||||||
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
|
siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
return siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
|
return siteConfig.label, siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||||
@@ -3941,27 +3951,27 @@ func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Requ
|
|||||||
rw.WriteHeader(http.StatusNoContent)
|
rw.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readChatAgentModelOverrideContext(
|
func readChatModelOverrideContext(
|
||||||
rw http.ResponseWriter,
|
rw http.ResponseWriter,
|
||||||
r *http.Request,
|
r *http.Request,
|
||||||
) (codersdk.ChatAgentModelOverrideContext, bool) {
|
) (codersdk.ChatModelOverrideContext, bool) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
rawContext := chi.URLParam(r, "context")
|
rawContext := chi.URLParam(r, "context")
|
||||||
overrideContext, err := parseChatAgentModelOverrideContext(rawContext)
|
overrideContext, err := parseChatModelOverrideContext(rawContext)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return overrideContext, true
|
return overrideContext, true
|
||||||
}
|
}
|
||||||
validContextValues := make(
|
validContextValues := make(
|
||||||
[]string,
|
[]string,
|
||||||
0,
|
0,
|
||||||
len(codersdk.AllChatAgentModelOverrideContexts()),
|
len(codersdk.AllChatModelOverrideContexts()),
|
||||||
)
|
)
|
||||||
for _, overrideContext := range codersdk.AllChatAgentModelOverrideContexts() {
|
for _, overrideContext := range codersdk.AllChatModelOverrideContexts() {
|
||||||
validContextValues = append(validContextValues, string(overrideContext))
|
validContextValues = append(validContextValues, string(overrideContext))
|
||||||
}
|
}
|
||||||
validContexts := strings.Join(validContextValues, ", ")
|
validContexts := strings.Join(validContextValues, ", ")
|
||||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: "Invalid chat agent model override context.",
|
Message: "Invalid chat model override context.",
|
||||||
Detail: fmt.Sprintf(
|
Detail: fmt.Sprintf(
|
||||||
"Expected one of %s. Got %q.",
|
"Expected one of %s. Got %q.",
|
||||||
validContexts,
|
validContexts,
|
||||||
@@ -3974,27 +3984,30 @@ func readChatAgentModelOverrideContext(
|
|||||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||||
//
|
//
|
||||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||||
func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) getChatModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
||||||
httpapi.ResourceNotFound(rw)
|
httpapi.ResourceNotFound(rw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
|
overrideContext, ok := readChatModelOverrideContext(rw, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
modelConfigID, isMalformed, err := api.getChatAgentModelOverrideConfig(ctx, overrideContext)
|
modelConfigID, isMalformed, label, err := api.readChatModelOverrideConfig(ctx, overrideContext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if label == "" {
|
||||||
|
label = string(overrideContext)
|
||||||
|
}
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: fmt.Sprintf("Internal error fetching %s model override.", overrideContext),
|
Message: fmt.Sprintf("Internal error fetching %s model override.", label),
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := codersdk.ChatAgentModelOverrideResponse{
|
resp := codersdk.ChatModelOverrideResponse{
|
||||||
Context: overrideContext,
|
Context: overrideContext,
|
||||||
ModelConfigID: formatChatModelOverride(modelConfigID),
|
ModelConfigID: formatChatModelOverride(modelConfigID),
|
||||||
IsMalformed: isMalformed,
|
IsMalformed: isMalformed,
|
||||||
@@ -4004,18 +4017,18 @@ func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Reques
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||||
func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) putChatModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||||
httpapi.Forbidden(rw)
|
httpapi.Forbidden(rw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
|
overrideContext, ok := readChatModelOverrideContext(rw, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req codersdk.UpdateChatAgentModelOverrideRequest
|
var req codersdk.UpdateChatModelOverrideRequest
|
||||||
if !httpapi.Read(ctx, rw, r, &req) {
|
if !httpapi.Read(ctx, rw, r, &req) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -4035,9 +4048,13 @@ func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := api.upsertChatAgentModelOverrideConfig(ctx, overrideContext, modelConfigID); err != nil {
|
label, err := api.upsertChatModelOverrideConfig(ctx, overrideContext, modelConfigID)
|
||||||
|
if err != nil {
|
||||||
|
if label == "" {
|
||||||
|
label = string(overrideContext)
|
||||||
|
}
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: fmt.Sprintf("Internal error updating %s model override.", overrideContext),
|
Message: fmt.Sprintf("Internal error updating %s model override.", label),
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
+27
-17
@@ -10061,28 +10061,28 @@ func TestChatModelOverrides(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
type overrideResponse struct {
|
type overrideResponse struct {
|
||||||
context codersdk.ChatAgentModelOverrideContext
|
context codersdk.ChatModelOverrideContext
|
||||||
modelConfigID string
|
modelConfigID string
|
||||||
isMalformed bool
|
isMalformed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type settingTest struct {
|
type settingTest struct {
|
||||||
name string
|
name string
|
||||||
context codersdk.ChatAgentModelOverrideContext
|
context codersdk.ChatModelOverrideContext
|
||||||
dbGet func(context.Context, database.Store) (string, error)
|
dbGet func(context.Context, database.Store) (string, error)
|
||||||
dbUpsert func(context.Context, database.Store, string) error
|
dbUpsert func(context.Context, database.Store, string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
settingPath := func(overrideContext codersdk.ChatAgentModelOverrideContext) string {
|
settingPath := func(overrideContext codersdk.ChatModelOverrideContext) string {
|
||||||
return "/api/experimental/chats/config/agent-model-override/" + string(overrideContext)
|
return "/api/experimental/chats/config/model-override/" + string(overrideContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
getOverride := func(
|
getOverride := func(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *codersdk.ExperimentalClient,
|
client *codersdk.ExperimentalClient,
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
) (overrideResponse, error) {
|
) (overrideResponse, error) {
|
||||||
resp, err := client.GetChatAgentModelOverride(ctx, overrideContext)
|
resp, err := client.GetChatModelOverride(ctx, overrideContext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return overrideResponse{}, err
|
return overrideResponse{}, err
|
||||||
}
|
}
|
||||||
@@ -10096,20 +10096,20 @@ func TestChatModelOverrides(t *testing.T) {
|
|||||||
putOverride := func(
|
putOverride := func(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *codersdk.ExperimentalClient,
|
client *codersdk.ExperimentalClient,
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
modelConfigID string,
|
modelConfigID string,
|
||||||
) error {
|
) error {
|
||||||
return client.UpdateChatAgentModelOverride(
|
return client.UpdateChatModelOverride(
|
||||||
ctx,
|
ctx,
|
||||||
overrideContext,
|
overrideContext,
|
||||||
codersdk.UpdateChatAgentModelOverrideRequest{ModelConfigID: modelConfigID},
|
codersdk.UpdateChatModelOverrideRequest{ModelConfigID: modelConfigID},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := []settingTest{
|
settings := []settingTest{
|
||||||
{
|
{
|
||||||
name: "General",
|
name: "General",
|
||||||
context: codersdk.ChatAgentModelOverrideContextGeneral,
|
context: codersdk.ChatModelOverrideContextGeneral,
|
||||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||||
return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx))
|
return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||||
},
|
},
|
||||||
@@ -10119,7 +10119,7 @@ func TestChatModelOverrides(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Explore",
|
name: "Explore",
|
||||||
context: codersdk.ChatAgentModelOverrideContextExplore,
|
context: codersdk.ChatModelOverrideContextExplore,
|
||||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||||
return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx))
|
return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||||
},
|
},
|
||||||
@@ -10127,6 +10127,16 @@ func TestChatModelOverrides(t *testing.T) {
|
|||||||
return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "TitleGeneration",
|
||||||
|
context: codersdk.ChatModelOverrideContextTitleGeneration,
|
||||||
|
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||||
|
return db.GetChatTitleGenerationModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||||
|
},
|
||||||
|
dbUpsert: func(ctx context.Context, db database.Store, value string) error {
|
||||||
|
return db.UpsertChatTitleGenerationModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, setting := range settings {
|
for _, setting := range settings {
|
||||||
@@ -10265,23 +10275,23 @@ func TestChatModelOverrides(t *testing.T) {
|
|||||||
|
|
||||||
adminClient := newChatClient(t)
|
adminClient := newChatClient(t)
|
||||||
coderdtest.CreateFirstUser(t, adminClient.Client)
|
coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||||
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
|
unknownContext := codersdk.ChatModelOverrideContext("not-a-context")
|
||||||
|
|
||||||
_, err := getOverride(ctx, adminClient, unknownContext)
|
_, err := getOverride(ctx, adminClient, unknownContext)
|
||||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||||
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
|
require.Equal(t, "Invalid chat model override context.", sdkErr.Message)
|
||||||
require.Equal(
|
require.Equal(
|
||||||
t,
|
t,
|
||||||
`Expected one of general, explore. Got "not-a-context".`,
|
`Expected one of general, explore, title_generation. Got "not-a-context".`,
|
||||||
sdkErr.Detail,
|
sdkErr.Detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = putOverride(ctx, adminClient, unknownContext, "")
|
err = putOverride(ctx, adminClient, unknownContext, "")
|
||||||
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
|
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
|
||||||
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
|
require.Equal(t, "Invalid chat model override context.", sdkErr.Message)
|
||||||
require.Equal(
|
require.Equal(
|
||||||
t,
|
t,
|
||||||
`Expected one of general, explore. Got "not-a-context".`,
|
`Expected one of general, explore, title_generation. Got "not-a-context".`,
|
||||||
sdkErr.Detail,
|
sdkErr.Detail,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@@ -10293,7 +10303,7 @@ func TestChatModelOverrides(t *testing.T) {
|
|||||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||||
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
|
unknownContext := codersdk.ChatModelOverrideContext("not-a-context")
|
||||||
|
|
||||||
_, err := getOverride(ctx, memberClient, unknownContext)
|
_, err := getOverride(ctx, memberClient, unknownContext)
|
||||||
requireSDKError(t, err, http.StatusNotFound)
|
requireSDKError(t, err, http.StatusNotFound)
|
||||||
|
|||||||
@@ -3110,6 +3110,26 @@ func (p *Server) resolveManualTitleModel(
|
|||||||
chat database.Chat,
|
chat database.Chat,
|
||||||
keys chatprovider.ProviderAPIKeys,
|
keys chatprovider.ProviderAPIKeys,
|
||||||
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
|
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
|
||||||
|
overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
keys,
|
||||||
|
)
|
||||||
|
if overrideErr != nil {
|
||||||
|
if overrideSet {
|
||||||
|
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
||||||
|
"resolve manual title generation model override: %w",
|
||||||
|
overrideErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
p.logger.Debug(ctx, "failed to resolve title generation model override for manual title",
|
||||||
|
slog.F("chat_id", chat.ID),
|
||||||
|
slog.Error(overrideErr),
|
||||||
|
)
|
||||||
|
} else if overrideSet {
|
||||||
|
return overrideModel, overrideConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logger.Debug(ctx, "failed to list manual title model configs",
|
p.logger.Debug(ctx, "failed to list manual title model configs",
|
||||||
|
|||||||
@@ -636,6 +636,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
|||||||
LimitVal: manualTitleMessageWindowLimit,
|
LimitVal: manualTitleMessageWindowLimit,
|
||||||
},
|
},
|
||||||
).Return(nil, nil)
|
).Return(nil, nil)
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
||||||
|
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
@@ -799,6 +800,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
|||||||
LimitVal: manualTitleMessageWindowLimit,
|
LimitVal: manualTitleMessageWindowLimit,
|
||||||
},
|
},
|
||||||
).Return(nil, nil)
|
).Return(nil, nil)
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
||||||
|
|
||||||
gomock.InOrder(
|
gomock.InOrder(
|
||||||
|
|||||||
@@ -106,9 +106,10 @@ type generatedTitle struct {
|
|||||||
// maybeGenerateChatTitle generates an AI title for the chat when
|
// maybeGenerateChatTitle generates an AI title for the chat when
|
||||||
// appropriate (first user message, no assistant reply yet, and the
|
// appropriate (first user message, no assistant reply yet, and the
|
||||||
// current title is either empty or still the fallback truncation).
|
// current title is either empty or still the fallback truncation).
|
||||||
// It tries cheap, fast models first and falls back to the user's
|
// It uses the configured title generation model override when set.
|
||||||
// chat model. It is a best-effort operation that logs and swallows
|
// Otherwise, it tries cheap, fast models first and falls back to the
|
||||||
// errors.
|
// user's chat model. It is a best-effort operation that logs and
|
||||||
|
// swallows errors.
|
||||||
func (p *Server) maybeGenerateChatTitle(
|
func (p *Server) maybeGenerateChatTitle(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
chat database.Chat,
|
chat database.Chat,
|
||||||
@@ -130,9 +131,38 @@ func (p *Server) maybeGenerateChatTitle(
|
|||||||
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||||
|
titleCtx,
|
||||||
|
chat,
|
||||||
|
keys,
|
||||||
|
)
|
||||||
|
if overrideErr != nil {
|
||||||
|
if overrideSet {
|
||||||
|
logger.Warn(ctx, "title generation model override unavailable, skipping title generation",
|
||||||
|
slog.F("chat_id", chat.ID),
|
||||||
|
slog.F("override_context", titleGenerationOverrideContext),
|
||||||
|
slog.Error(overrideErr),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Debug(ctx, "failed to resolve title generation model override",
|
||||||
|
slog.F("chat_id", chat.ID),
|
||||||
|
slog.F("override_context", titleGenerationOverrideContext),
|
||||||
|
slog.Error(overrideErr),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var candidates []shortTextCandidate
|
||||||
|
if overrideSet {
|
||||||
|
candidates = []shortTextCandidate{{
|
||||||
|
provider: overrideConfig.Provider,
|
||||||
|
model: overrideConfig.Model,
|
||||||
|
lm: overrideModel,
|
||||||
|
}}
|
||||||
|
} else {
|
||||||
// Build candidate list: preferred lightweight models first,
|
// Build candidate list: preferred lightweight models first,
|
||||||
// then the user's chat model as last resort.
|
// then the user's chat model as last resort.
|
||||||
candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
candidates = make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
||||||
for _, c := range preferredTitleModels {
|
for _, c := range preferredTitleModels {
|
||||||
m, err := chatprovider.ModelFromConfig(
|
m, err := chatprovider.ModelFromConfig(
|
||||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||||
@@ -152,6 +182,7 @@ func (p *Server) maybeGenerateChatTitle(
|
|||||||
model: fallbackModelName,
|
model: fallbackModelName,
|
||||||
lm: fallbackModel,
|
lm: fallbackModel,
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
var historyTipMessageID int64
|
var historyTipMessageID int64
|
||||||
if len(messages) > 0 {
|
if len(messages) > 0 {
|
||||||
@@ -197,10 +228,20 @@ func (p *Server) maybeGenerateChatTitle(
|
|||||||
finishDebugRun(err)
|
finishDebugRun(err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
|
if overrideSet {
|
||||||
|
logger.Warn(ctx, "title model candidate failed",
|
||||||
|
slog.F("chat_id", chat.ID),
|
||||||
|
slog.F("override_context", titleGenerationOverrideContext),
|
||||||
|
slog.F("provider", candidate.provider),
|
||||||
|
slog.F("model", candidate.model),
|
||||||
|
slog.Error(err),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
logger.Debug(ctx, "title model candidate failed",
|
logger.Debug(ctx, "title model candidate failed",
|
||||||
slog.F("chat_id", chat.ID),
|
slog.F("chat_id", chat.ID),
|
||||||
slog.Error(err),
|
slog.Error(err),
|
||||||
)
|
)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if title == "" || title == chat.Title {
|
if title == "" || title == chat.Title {
|
||||||
@@ -225,12 +266,20 @@ func (p *Server) maybeGenerateChatTitle(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if lastErr != nil {
|
if lastErr != nil {
|
||||||
|
if overrideSet {
|
||||||
|
logger.Warn(ctx, "all title model candidates failed",
|
||||||
|
slog.F("chat_id", chat.ID),
|
||||||
|
slog.F("override_context", titleGenerationOverrideContext),
|
||||||
|
slog.Error(lastErr),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
logger.Debug(ctx, "all title model candidates failed",
|
logger.Debug(ctx, "all title model candidates failed",
|
||||||
slog.F("chat_id", chat.ID),
|
slog.F("chat_id", chat.ID),
|
||||||
slog.Error(lastErr),
|
slog.Error(lastErr),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newQuickgenDebugModel(
|
func newQuickgenDebugModel(
|
||||||
chat database.Chat,
|
chat database.Chat,
|
||||||
|
|||||||
@@ -104,12 +104,12 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func subagentModelOverrideLogLabel(
|
func subagentModelOverrideLogLabel(
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
) string {
|
) string {
|
||||||
switch overrideContext {
|
switch overrideContext {
|
||||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
case codersdk.ChatModelOverrideContextGeneral:
|
||||||
return "general delegated child"
|
return "general delegated child"
|
||||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
case codersdk.ChatModelOverrideContextExplore:
|
||||||
return "explore"
|
return "explore"
|
||||||
default:
|
default:
|
||||||
return string(overrideContext)
|
return string(overrideContext)
|
||||||
@@ -119,16 +119,16 @@ func subagentModelOverrideLogLabel(
|
|||||||
func readSubagentModelOverride(
|
func readSubagentModelOverride(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db database.Store,
|
db database.Store,
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
switch overrideContext {
|
switch overrideContext {
|
||||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
case codersdk.ChatModelOverrideContextGeneral:
|
||||||
return db.GetChatGeneralModelOverride(ctx)
|
return db.GetChatGeneralModelOverride(ctx)
|
||||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
case codersdk.ChatModelOverrideContextExplore:
|
||||||
return db.GetChatExploreModelOverride(ctx)
|
return db.GetChatExploreModelOverride(ctx)
|
||||||
default:
|
default:
|
||||||
return "", xerrors.Errorf(
|
return "", xerrors.Errorf(
|
||||||
"unknown subagent model override context %q",
|
"unsupported subagent model override context %q",
|
||||||
overrideContext,
|
overrideContext,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -167,6 +167,20 @@ func enabledProviderContainsName(
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type modelOverrideFailureMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
modelOverrideFailureModeSoft modelOverrideFailureMode = iota
|
||||||
|
modelOverrideFailureModeHard
|
||||||
|
)
|
||||||
|
|
||||||
|
func modelOverrideErrorLabel(overrideContext string) string {
|
||||||
|
return strings.ReplaceAll(overrideContext, "_", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveConfiguredModelOverride returns ok when a usable override is
|
||||||
|
// resolved. In hard failure mode, ok is also true for configured but unusable
|
||||||
|
// overrides so callers can distinguish them from unset or malformed values.
|
||||||
func (p *Server) resolveConfiguredModelOverride(
|
func (p *Server) resolveConfiguredModelOverride(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
overrideContext string,
|
overrideContext string,
|
||||||
@@ -174,6 +188,7 @@ func (p *Server) resolveConfiguredModelOverride(
|
|||||||
ownerID uuid.UUID,
|
ownerID uuid.UUID,
|
||||||
resolveModelConfig modelOverrideConfigResolver,
|
resolveModelConfig modelOverrideConfigResolver,
|
||||||
resolveProviderKeys modelOverrideProviderKeysResolver,
|
resolveProviderKeys modelOverrideProviderKeysResolver,
|
||||||
|
failureMode modelOverrideFailureMode,
|
||||||
) (database.ChatModelConfig, bool, error) {
|
) (database.ChatModelConfig, bool, error) {
|
||||||
trimmed := strings.TrimSpace(raw)
|
trimmed := strings.TrimSpace(raw)
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
@@ -189,13 +204,40 @@ func (p *Server) resolveConfiguredModelOverride(
|
|||||||
)
|
)
|
||||||
return database.ChatModelConfig{}, false, nil
|
return database.ChatModelConfig{}, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
modelConfig, providerName, err := resolveModelConfig(
|
modelConfig, providerName, err := resolveModelConfig(
|
||||||
ctx,
|
ctx,
|
||||||
configuredModelConfigID,
|
configuredModelConfigID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if failureMode == modelOverrideFailureModeHard {
|
||||||
|
label := modelOverrideErrorLabel(overrideContext)
|
||||||
switch {
|
switch {
|
||||||
case xerrors.Is(err, sql.ErrNoRows):
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
|
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||||
|
"%s model override is unavailable: %s",
|
||||||
|
label,
|
||||||
|
configuredModelConfigID,
|
||||||
|
)
|
||||||
|
case errors.Is(err, errInvalidModelOverrideMetadata):
|
||||||
|
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||||
|
"%s model override metadata is invalid for %s: %w",
|
||||||
|
label,
|
||||||
|
configuredModelConfigID,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||||
|
"resolve %s model override %s: %w",
|
||||||
|
label,
|
||||||
|
configuredModelConfigID,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
p.logger.Info(ctx,
|
p.logger.Info(ctx,
|
||||||
"model override is unavailable, ignoring",
|
"model override is unavailable, ignoring",
|
||||||
slog.F("override_context", overrideContext),
|
slog.F("override_context", overrideContext),
|
||||||
@@ -218,6 +260,7 @@ func (p *Server) resolveConfiguredModelOverride(
|
|||||||
}
|
}
|
||||||
return database.ChatModelConfig{}, false, nil
|
return database.ChatModelConfig{}, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
providerKeys, err := resolveProviderKeys(ctx, ownerID)
|
providerKeys, err := resolveProviderKeys(ctx, ownerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return database.ChatModelConfig{}, false, xerrors.Errorf(
|
return database.ChatModelConfig{}, false, xerrors.Errorf(
|
||||||
@@ -228,6 +271,14 @@ func (p *Server) resolveConfiguredModelOverride(
|
|||||||
if providerKeys.APIKey(providerName) == "" &&
|
if providerKeys.APIKey(providerName) == "" &&
|
||||||
!(chatprovider.ProviderAllowsAmbientCredentials(providerName) &&
|
!(chatprovider.ProviderAllowsAmbientCredentials(providerName) &&
|
||||||
providerKeys.HasProvider(providerName)) {
|
providerKeys.HasProvider(providerName)) {
|
||||||
|
if failureMode == modelOverrideFailureModeHard {
|
||||||
|
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||||
|
"%s model override credentials are unavailable for provider %q",
|
||||||
|
modelOverrideErrorLabel(overrideContext),
|
||||||
|
providerName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
p.logger.Info(ctx,
|
p.logger.Info(ctx,
|
||||||
"model override credentials are unavailable, ignoring",
|
"model override credentials are unavailable, ignoring",
|
||||||
slog.F("override_context", overrideContext),
|
slog.F("override_context", overrideContext),
|
||||||
@@ -242,7 +293,7 @@ func (p *Server) resolveConfiguredModelOverride(
|
|||||||
func (p *Server) resolveSubagentModelConfigID(
|
func (p *Server) resolveSubagentModelConfigID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
ownerID uuid.UUID,
|
ownerID uuid.UUID,
|
||||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
overrideContext codersdk.ChatModelOverrideContext,
|
||||||
) (uuid.UUID, error) {
|
) (uuid.UUID, error) {
|
||||||
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
|
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
|
||||||
chatdCtx := dbauthz.AsChatd(ctx)
|
chatdCtx := dbauthz.AsChatd(ctx)
|
||||||
@@ -261,6 +312,7 @@ func (p *Server) resolveSubagentModelConfigID(
|
|||||||
ownerID,
|
ownerID,
|
||||||
p.resolveModelConfigAndNormalizedProvider,
|
p.resolveModelConfigAndNormalizedProvider,
|
||||||
p.resolveUserProviderAPIKeys,
|
p.resolveUserProviderAPIKeys,
|
||||||
|
modelOverrideFailureModeSoft,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return uuid.Nil, err
|
return uuid.Nil, err
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func allSubagentDefinitions() []subagentDefinition {
|
|||||||
modelConfigID, err := p.resolveSubagentModelConfigID(
|
modelConfigID, err := p.resolveSubagentModelConfigID(
|
||||||
ctx,
|
ctx,
|
||||||
parent.OwnerID,
|
parent.OwnerID,
|
||||||
codersdk.ChatAgentModelOverrideContextGeneral,
|
codersdk.ChatModelOverrideContextGeneral,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return childSubagentChatOptions{}, err
|
return childSubagentChatOptions{}, err
|
||||||
@@ -67,7 +67,7 @@ func allSubagentDefinitions() []subagentDefinition {
|
|||||||
modelConfigID, err := p.resolveSubagentModelConfigID(
|
modelConfigID, err := p.resolveSubagentModelConfigID(
|
||||||
ctx,
|
ctx,
|
||||||
turnParent.OwnerID,
|
turnParent.OwnerID,
|
||||||
codersdk.ChatAgentModelOverrideContextExplore,
|
codersdk.ChatModelOverrideContextExplore,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return childSubagentChatOptions{}, err
|
return childSubagentChatOptions{}, err
|
||||||
|
|||||||
@@ -834,6 +834,7 @@ func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider(
|
|||||||
ByProvider: map[string]string{"bedrock": ""},
|
ByProvider: map[string]string{"bedrock": ""},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
modelOverrideFailureModeSoft,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|||||||
@@ -0,0 +1,100 @@
|
|||||||
|
package chatd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"charm.land/fantasy"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||||
|
)
|
||||||
|
|
||||||
|
const titleGenerationOverrideContext = "title_generation"
|
||||||
|
|
||||||
|
func readTitleGenerationModelOverride(
|
||||||
|
ctx context.Context,
|
||||||
|
db database.Store,
|
||||||
|
) (string, error) {
|
||||||
|
//nolint:gocritic // Chatd is internal, not a user, so this read uses AsChatd.
|
||||||
|
chatdCtx := dbauthz.AsChatd(ctx)
|
||||||
|
raw, err := db.GetChatTitleGenerationModelOverride(chatdCtx)
|
||||||
|
if err != nil {
|
||||||
|
return "", xerrors.Errorf(
|
||||||
|
"get chat title generation model override: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveTitleGenerationModelOverride resolves the deployment-wide title
|
||||||
|
// generation model override. It returns four values:
|
||||||
|
//
|
||||||
|
// - modelConfig and model: populated only on success.
|
||||||
|
// - overrideSet: true when the admin configured a non-empty override,
|
||||||
|
// regardless of whether resolution succeeded. Callers MUST always check
|
||||||
|
// err first; overrideSet alone does not imply the model is usable.
|
||||||
|
// - err: non-nil when resolution failed. DB read failure returns
|
||||||
|
// (zero, nil, false, err). With overrideSet=true, the override is
|
||||||
|
// configured but unusable (deleted model, missing credentials, etc.) and
|
||||||
|
// callers should treat this as a hard failure for explicit-override
|
||||||
|
// semantics, not a soft fallback.
|
||||||
|
//
|
||||||
|
// When the override is unset or stored as malformed, the function returns
|
||||||
|
// (zero, nil, false, nil) so callers can fall back to default behavior.
|
||||||
|
func (p *Server) resolveTitleGenerationModelOverride(
|
||||||
|
ctx context.Context,
|
||||||
|
chat database.Chat,
|
||||||
|
keys chatprovider.ProviderAPIKeys,
|
||||||
|
) (database.ChatModelConfig, fantasy.LanguageModel, bool, error) {
|
||||||
|
raw, err := readTitleGenerationModelOverride(ctx, p.db)
|
||||||
|
if err != nil {
|
||||||
|
return database.ChatModelConfig{}, nil, false, xerrors.Errorf(
|
||||||
|
"read title generation model override: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelConfig, overrideSet, err := p.resolveConfiguredModelOverride(
|
||||||
|
ctx,
|
||||||
|
titleGenerationOverrideContext,
|
||||||
|
raw,
|
||||||
|
chat.OwnerID,
|
||||||
|
p.resolveModelConfigAndNormalizedProvider,
|
||||||
|
func(context.Context, uuid.UUID) (chatprovider.ProviderAPIKeys, error) {
|
||||||
|
return keys, nil
|
||||||
|
},
|
||||||
|
modelOverrideFailureModeHard,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return database.ChatModelConfig{}, nil, overrideSet, err
|
||||||
|
}
|
||||||
|
if !overrideSet {
|
||||||
|
return database.ChatModelConfig{}, nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := chatprovider.ModelFromConfig(
|
||||||
|
modelConfig.Provider,
|
||||||
|
modelConfig.Model,
|
||||||
|
keys,
|
||||||
|
chatprovider.UserAgent(),
|
||||||
|
chatprovider.CoderHeaders(chat),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return database.ChatModelConfig{}, nil, true, xerrors.Errorf(
|
||||||
|
"create title generation model override: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if model == nil {
|
||||||
|
return database.ChatModelConfig{}, nil, true, xerrors.Errorf(
|
||||||
|
"create title generation model override returned nil",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelConfig, model, true, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,559 @@
|
|||||||
|
package chatd //nolint:testpackage // Tests internal title override helpers.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"charm.land/fantasy"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/mock/gomock"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"cdr.dev/slog/v3"
|
||||||
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||||
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||||
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
"github.com/coder/quartz"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("uses preferred model before fallback", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
wantTitle := "Preferred title"
|
||||||
|
|
||||||
|
var requestCount atomic.Int32
|
||||||
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||||
|
requestCount.Add(1)
|
||||||
|
require.Equal(t, preferredTitleModels[1].model, req.Model)
|
||||||
|
return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`)
|
||||||
|
})
|
||||||
|
keys := titleOverrideOpenAIKeys(serverURL)
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
t.Fatal("fallback model should not be called when preferred model works")
|
||||||
|
return nil, xerrors.New("unexpected fallback model call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||||
|
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||||
|
ID: chat.ID,
|
||||||
|
Title: wantTitle,
|
||||||
|
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
keys,
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), requestCount.Load())
|
||||||
|
gotTitle, ok := generated.Load()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, wantTitle, gotTitle)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to chat model when preferred models are unavailable", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
wantTitle := "Fallback title"
|
||||||
|
|
||||||
|
var fallbackCalls atomic.Int32
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
fallbackCalls.Add(1)
|
||||||
|
return &fantasy.ObjectResponse{
|
||||||
|
Object: map[string]any{"title": wantTitle},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||||
|
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||||
|
ID: chat.ID,
|
||||||
|
Title: wantTitle,
|
||||||
|
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
chatprovider.ProviderAPIKeys{},
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), fallbackCalls.Load())
|
||||||
|
gotTitle, ok := generated.Load()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, wantTitle, gotTitle)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeGenerateChatTitle_TitleGenerationOverrideReadDBError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
wantTitle := "Fallback title"
|
||||||
|
|
||||||
|
var fallbackCalls atomic.Int32
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
fallbackCalls.Add(1)
|
||||||
|
return &fantasy.ObjectResponse{
|
||||||
|
Object: map[string]any{"title": wantTitle},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", sql.ErrConnDone)
|
||||||
|
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||||
|
ID: chat.ID,
|
||||||
|
Title: wantTitle,
|
||||||
|
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
chatprovider.ProviderAPIKeys{},
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), fallbackCalls.Load())
|
||||||
|
gotTitle, ok := generated.Load()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, wantTitle, gotTitle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeGenerateChatTitle_TitleGenerationOverrideMalformedFallsThrough(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
wantTitle := "Fallback title"
|
||||||
|
|
||||||
|
var fallbackCalls atomic.Int32
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
fallbackCalls.Add(1)
|
||||||
|
return &fantasy.ObjectResponse{
|
||||||
|
Object: map[string]any{"title": wantTitle},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("not-a-uuid", nil)
|
||||||
|
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||||
|
ID: chat.ID,
|
||||||
|
Title: wantTitle,
|
||||||
|
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
chatprovider.ProviderAPIKeys{},
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), fallbackCalls.Load())
|
||||||
|
gotTitle, ok := generated.Load()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, wantTitle, gotTitle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||||
|
wantTitle := "Override title"
|
||||||
|
|
||||||
|
var requestCount atomic.Int32
|
||||||
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||||
|
requestCount.Add(1)
|
||||||
|
require.Equal(t, overrideConfig.Model, req.Model)
|
||||||
|
return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`)
|
||||||
|
})
|
||||||
|
keys := titleOverrideOpenAIKeys(serverURL)
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
t.Fatal("fallback model should not be called when override is usable")
|
||||||
|
return nil, xerrors.New("unexpected fallback model call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||||
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||||
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||||
|
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||||
|
ID: chat.ID,
|
||||||
|
Title: wantTitle,
|
||||||
|
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
keys,
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), requestCount.Load())
|
||||||
|
gotTitle, ok := generated.Load()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, wantTitle, gotTitle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUnusableSkips(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
overrideConfig := titleOverrideModelConfig("gpt-4.1", false)
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
t.Fatal("fallback model should not be called when override is unusable")
|
||||||
|
return nil, xerrors.New("unexpected fallback model call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||||
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
chatprovider.ProviderAPIKeys{},
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, ok := generated.Load()
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||||
|
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||||
|
|
||||||
|
var requestCount atomic.Int32
|
||||||
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||||
|
requestCount.Add(1)
|
||||||
|
require.Equal(t, overrideConfig.Model, req.Model)
|
||||||
|
return chattest.OpenAINonStreamingResponse(`{"title":""}`)
|
||||||
|
})
|
||||||
|
keys := titleOverrideOpenAIKeys(serverURL)
|
||||||
|
fallbackModel := &chattest.FakeModel{
|
||||||
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||||
|
t.Fatal("fallback model should not be called after override call failure")
|
||||||
|
return nil, xerrors.New("unexpected fallback model call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||||
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||||
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||||
|
|
||||||
|
generated := &generatedChatTitle{}
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
server.maybeGenerateChatTitle(
|
||||||
|
ctx,
|
||||||
|
chat,
|
||||||
|
messages,
|
||||||
|
"openai",
|
||||||
|
"fallback-chat-model",
|
||||||
|
fallbackModel,
|
||||||
|
keys,
|
||||||
|
generated,
|
||||||
|
logger,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Equal(t, int32(1), requestCount.Load())
|
||||||
|
_, ok := generated.Load()
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||||
|
preferredConfig := database.ChatModelConfig{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Provider: preferredTitleModels[1].provider,
|
||||||
|
Model: preferredTitleModels[1].model,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||||
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{
|
||||||
|
{Provider: "openai", Model: "gpt-4.1", Enabled: true},
|
||||||
|
preferredConfig,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
model, gotConfig, err := server.resolveManualTitleModel(
|
||||||
|
ctx,
|
||||||
|
db,
|
||||||
|
chat,
|
||||||
|
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, model)
|
||||||
|
require.Equal(t, preferredConfig, gotConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||||
|
preferredConfig := database.ChatModelConfig{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Provider: preferredTitleModels[1].provider,
|
||||||
|
Model: preferredTitleModels[1].model,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", sql.ErrConnDone)
|
||||||
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{
|
||||||
|
{Provider: "openai", Model: "gpt-4.1", Enabled: true},
|
||||||
|
preferredConfig,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
model, gotConfig, err := server.resolveManualTitleModel(
|
||||||
|
ctx,
|
||||||
|
db,
|
||||||
|
chat,
|
||||||
|
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, model)
|
||||||
|
require.Equal(t, preferredConfig, gotConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||||
|
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||||
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||||
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||||
|
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
model, gotConfig, err := server.resolveManualTitleModel(
|
||||||
|
ctx,
|
||||||
|
db,
|
||||||
|
chat,
|
||||||
|
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, model)
|
||||||
|
require.Equal(t, overrideConfig, gotConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||||
|
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||||
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||||
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||||
|
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
model, gotConfig, err := server.resolveManualTitleModel(
|
||||||
|
ctx,
|
||||||
|
db,
|
||||||
|
chat,
|
||||||
|
chatprovider.ProviderAPIKeys{},
|
||||||
|
)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "resolve manual title generation model override")
|
||||||
|
require.ErrorContains(t, err, "credentials are unavailable")
|
||||||
|
require.Nil(t, model)
|
||||||
|
require.Equal(t, database.ChatModelConfig{}, gotConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
db := dbmock.NewMockStore(ctrl)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||||
|
overrideConfig := titleOverrideModelConfig("gpt-4.1", false)
|
||||||
|
|
||||||
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||||
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||||
|
|
||||||
|
server := titleOverrideTestServer(db, logger)
|
||||||
|
model, gotConfig, err := server.resolveManualTitleModel(
|
||||||
|
ctx,
|
||||||
|
db,
|
||||||
|
chat,
|
||||||
|
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||||
|
)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "resolve manual title generation model override")
|
||||||
|
require.ErrorContains(t, err, "title generation model override is unavailable")
|
||||||
|
require.Nil(t, model)
|
||||||
|
require.Equal(t, database.ChatModelConfig{}, gotConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func titleOverrideTestChatAndMessages(t *testing.T) (database.Chat, []database.ChatMessage) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
userPrompt := "review pull request 123 and fix comments"
|
||||||
|
chat := database.Chat{
|
||||||
|
ID: uuid.New(),
|
||||||
|
OwnerID: uuid.New(),
|
||||||
|
Title: fallbackChatTitle(userPrompt),
|
||||||
|
}
|
||||||
|
message := mustChatMessage(
|
||||||
|
t,
|
||||||
|
database.ChatMessageRoleUser,
|
||||||
|
database.ChatMessageVisibilityBoth,
|
||||||
|
codersdk.ChatMessageText(userPrompt),
|
||||||
|
)
|
||||||
|
message.ID = 1
|
||||||
|
return chat, []database.ChatMessage{message}
|
||||||
|
}
|
||||||
|
|
||||||
|
func titleOverrideTestServer(db database.Store, logger slog.Logger) *Server {
|
||||||
|
return &Server{
|
||||||
|
db: db,
|
||||||
|
logger: logger,
|
||||||
|
configCache: newChatConfigCache(context.Background(), db, quartz.NewReal()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func titleOverrideModelConfig(model string, enabled bool) database.ChatModelConfig {
|
||||||
|
return database.ChatModelConfig{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Provider: "openai",
|
||||||
|
Model: model,
|
||||||
|
Enabled: enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func titleOverrideOpenAIKeys(serverURL string) chatprovider.ProviderAPIKeys {
|
||||||
|
return chatprovider.ProviderAPIKeys{
|
||||||
|
ByProvider: map[string]string{
|
||||||
|
"openai": "test-key",
|
||||||
|
},
|
||||||
|
BaseURLByProvider: map[string]string{
|
||||||
|
"openai": serverURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatWithTitle(chat database.Chat, title string) database.Chat {
|
||||||
|
chat.Title = title
|
||||||
|
return chat
|
||||||
|
}
|
||||||
+34
-31
@@ -562,45 +562,48 @@ type UpdateChatPlanModeInstructionsRequest struct {
|
|||||||
PlanModeInstructions string `json:"plan_mode_instructions"`
|
PlanModeInstructions string `json:"plan_mode_instructions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatAgentModelOverrideContext identifies which chat or subagent context
|
// ChatModelOverrideContext identifies which chat model override context a
|
||||||
// a deployment override applies to.
|
// deployment override applies to.
|
||||||
type ChatAgentModelOverrideContext string
|
type ChatModelOverrideContext string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChatAgentModelOverrideContextGeneral ChatAgentModelOverrideContext = "general"
|
ChatModelOverrideContextGeneral ChatModelOverrideContext = "general"
|
||||||
ChatAgentModelOverrideContextExplore ChatAgentModelOverrideContext = "explore"
|
ChatModelOverrideContextExplore ChatModelOverrideContext = "explore"
|
||||||
|
ChatModelOverrideContextTitleGeneration ChatModelOverrideContext = "title_generation"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Valid reports whether the override context is one of the supported values.
|
// Valid reports whether the override context is one of the supported values.
|
||||||
func (c ChatAgentModelOverrideContext) Valid() bool {
|
func (c ChatModelOverrideContext) Valid() bool {
|
||||||
switch c {
|
switch c {
|
||||||
case ChatAgentModelOverrideContextGeneral,
|
case ChatModelOverrideContextGeneral,
|
||||||
ChatAgentModelOverrideContextExplore:
|
ChatModelOverrideContextExplore,
|
||||||
|
ChatModelOverrideContextTitleGeneration:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllChatAgentModelOverrideContexts returns all supported override contexts.
|
// AllChatModelOverrideContexts returns all supported override contexts.
|
||||||
func AllChatAgentModelOverrideContexts() []ChatAgentModelOverrideContext {
|
func AllChatModelOverrideContexts() []ChatModelOverrideContext {
|
||||||
return []ChatAgentModelOverrideContext{
|
return []ChatModelOverrideContext{
|
||||||
ChatAgentModelOverrideContextGeneral,
|
ChatModelOverrideContextGeneral,
|
||||||
ChatAgentModelOverrideContextExplore,
|
ChatModelOverrideContextExplore,
|
||||||
|
ChatModelOverrideContextTitleGeneration,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatAgentModelOverrideResponse is the response body for the chat agent
|
// ChatModelOverrideResponse is the response body for the chat model override
|
||||||
// model override configuration endpoint.
|
// configuration endpoint.
|
||||||
type ChatAgentModelOverrideResponse struct {
|
type ChatModelOverrideResponse struct {
|
||||||
Context ChatAgentModelOverrideContext `json:"context"`
|
Context ChatModelOverrideContext `json:"context"`
|
||||||
ModelConfigID string `json:"model_config_id"`
|
ModelConfigID string `json:"model_config_id"`
|
||||||
IsMalformed bool `json:"is_malformed"`
|
IsMalformed bool `json:"is_malformed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChatAgentModelOverrideRequest is the request body for updating the
|
// UpdateChatModelOverrideRequest is the request body for updating the chat
|
||||||
// chat agent model override configuration endpoint.
|
// model override configuration endpoint.
|
||||||
type UpdateChatAgentModelOverrideRequest struct {
|
type UpdateChatModelOverrideRequest struct {
|
||||||
ModelConfigID string `json:"model_config_id"`
|
ModelConfigID string `json:"model_config_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2098,30 +2101,30 @@ func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChatAgentModelOverride returns the deployment-wide chat agent model
|
// GetChatModelOverride returns the deployment-wide chat model override for
|
||||||
// override for the requested context.
|
// the requested context.
|
||||||
func (c *ExperimentalClient) GetChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext) (ChatAgentModelOverrideResponse, error) {
|
func (c *ExperimentalClient) GetChatModelOverride(ctx context.Context, override ChatModelOverrideContext) (ChatModelOverrideResponse, error) {
|
||||||
path := fmt.Sprintf(
|
path := fmt.Sprintf(
|
||||||
"/api/experimental/chats/config/agent-model-override/%s",
|
"/api/experimental/chats/config/model-override/%s",
|
||||||
url.PathEscape(string(override)),
|
url.PathEscape(string(override)),
|
||||||
)
|
)
|
||||||
res, err := c.Request(ctx, http.MethodGet, path, nil)
|
res, err := c.Request(ctx, http.MethodGet, path, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ChatAgentModelOverrideResponse{}, err
|
return ChatModelOverrideResponse{}, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
if res.StatusCode != http.StatusOK {
|
if res.StatusCode != http.StatusOK {
|
||||||
return ChatAgentModelOverrideResponse{}, ReadBodyAsError(res)
|
return ChatModelOverrideResponse{}, ReadBodyAsError(res)
|
||||||
}
|
}
|
||||||
var resp ChatAgentModelOverrideResponse
|
var resp ChatModelOverrideResponse
|
||||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChatAgentModelOverride updates the deployment-wide chat agent model
|
// UpdateChatModelOverride updates the deployment-wide chat model override for
|
||||||
// override for the requested context.
|
// the requested context.
|
||||||
func (c *ExperimentalClient) UpdateChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext, req UpdateChatAgentModelOverrideRequest) error {
|
func (c *ExperimentalClient) UpdateChatModelOverride(ctx context.Context, override ChatModelOverrideContext, req UpdateChatModelOverrideRequest) error {
|
||||||
path := fmt.Sprintf(
|
path := fmt.Sprintf(
|
||||||
"/api/experimental/chats/config/agent-model-override/%s",
|
"/api/experimental/chats/config/model-override/%s",
|
||||||
url.PathEscape(string(override)),
|
url.PathEscape(string(override)),
|
||||||
)
|
)
|
||||||
res, err := c.Request(ctx, http.MethodPut, path, req)
|
res, err := c.Request(ctx, http.MethodPut, path, req)
|
||||||
|
|||||||
+9
-10
@@ -3262,22 +3262,21 @@ class ExperimentalApiMethods {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
getChatAgentModelOverride = async (
|
getChatModelOverride = async (
|
||||||
context: TypesGen.ChatAgentModelOverrideContext,
|
context: TypesGen.ChatModelOverrideContext,
|
||||||
): Promise<TypesGen.ChatAgentModelOverrideResponse> => {
|
): Promise<TypesGen.ChatModelOverrideResponse> => {
|
||||||
const response =
|
const response = await this.axios.get<TypesGen.ChatModelOverrideResponse>(
|
||||||
await this.axios.get<TypesGen.ChatAgentModelOverrideResponse>(
|
`/api/experimental/chats/config/model-override/${encodeURIComponent(context)}`,
|
||||||
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
|
|
||||||
);
|
);
|
||||||
return response.data;
|
return response.data;
|
||||||
};
|
};
|
||||||
|
|
||||||
updateChatAgentModelOverride = async (
|
updateChatModelOverride = async (
|
||||||
context: TypesGen.ChatAgentModelOverrideContext,
|
context: TypesGen.ChatModelOverrideContext,
|
||||||
req: TypesGen.UpdateChatAgentModelOverrideRequest,
|
req: TypesGen.UpdateChatModelOverrideRequest,
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
await this.axios.put(
|
await this.axios.put(
|
||||||
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
|
`/api/experimental/chats/config/model-override/${encodeURIComponent(context)}`,
|
||||||
req,
|
req,
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
Generated
+32
-28
@@ -1320,25 +1320,6 @@ export interface Chat {
|
|||||||
readonly children: readonly Chat[];
|
readonly children: readonly Chat[];
|
||||||
}
|
}
|
||||||
|
|
||||||
// From codersdk/chats.go
|
|
||||||
export type ChatAgentModelOverrideContext = "explore" | "general";
|
|
||||||
|
|
||||||
export const ChatAgentModelOverrideContexts: ChatAgentModelOverrideContext[] = [
|
|
||||||
"explore",
|
|
||||||
"general",
|
|
||||||
];
|
|
||||||
|
|
||||||
// From codersdk/chats.go
|
|
||||||
/**
|
|
||||||
* ChatAgentModelOverrideResponse is the response body for the chat agent
|
|
||||||
* model override configuration endpoint.
|
|
||||||
*/
|
|
||||||
export interface ChatAgentModelOverrideResponse {
|
|
||||||
readonly context: ChatAgentModelOverrideContext;
|
|
||||||
readonly model_config_id: string;
|
|
||||||
readonly is_malformed: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
// From codersdk/chats.go
|
// From codersdk/chats.go
|
||||||
/**
|
/**
|
||||||
* ChatAutoArchiveDaysResponse contains the current chat auto-archive setting.
|
* ChatAutoArchiveDaysResponse contains the current chat auto-archive setting.
|
||||||
@@ -2095,6 +2076,29 @@ export interface ChatModelOpenRouterProviderOptions {
|
|||||||
readonly provider?: ChatModelOpenRouterProvider;
|
readonly provider?: ChatModelOpenRouterProvider;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// From codersdk/chats.go
|
||||||
|
export type ChatModelOverrideContext =
|
||||||
|
| "explore"
|
||||||
|
| "general"
|
||||||
|
| "title_generation";
|
||||||
|
|
||||||
|
export const ChatModelOverrideContexts: ChatModelOverrideContext[] = [
|
||||||
|
"explore",
|
||||||
|
"general",
|
||||||
|
"title_generation",
|
||||||
|
];
|
||||||
|
|
||||||
|
// From codersdk/chats.go
|
||||||
|
/**
|
||||||
|
* ChatModelOverrideResponse is the response body for the chat model override
|
||||||
|
* configuration endpoint.
|
||||||
|
*/
|
||||||
|
export interface ChatModelOverrideResponse {
|
||||||
|
readonly context: ChatModelOverrideContext;
|
||||||
|
readonly model_config_id: string;
|
||||||
|
readonly is_malformed: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
// From codersdk/chats.go
|
// From codersdk/chats.go
|
||||||
/**
|
/**
|
||||||
* ChatModelProvider represents provider availability and model results.
|
* ChatModelProvider represents provider availability and model results.
|
||||||
@@ -7804,15 +7808,6 @@ export interface UpdateAppearanceConfig {
|
|||||||
readonly announcement_banners: readonly BannerConfig[];
|
readonly announcement_banners: readonly BannerConfig[];
|
||||||
}
|
}
|
||||||
|
|
||||||
// From codersdk/chats.go
|
|
||||||
/**
|
|
||||||
* UpdateChatAgentModelOverrideRequest is the request body for updating the
|
|
||||||
* chat agent model override configuration endpoint.
|
|
||||||
*/
|
|
||||||
export interface UpdateChatAgentModelOverrideRequest {
|
|
||||||
readonly model_config_id: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
// From codersdk/chats.go
|
// From codersdk/chats.go
|
||||||
/**
|
/**
|
||||||
* UpdateChatAutoArchiveDaysRequest is a request to update the chat
|
* UpdateChatAutoArchiveDaysRequest is a request to update the chat
|
||||||
@@ -7854,6 +7849,15 @@ export interface UpdateChatModelConfigRequest {
|
|||||||
readonly model_config?: ChatModelCallConfig;
|
readonly model_config?: ChatModelCallConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// From codersdk/chats.go
|
||||||
|
/**
|
||||||
|
* UpdateChatModelOverrideRequest is the request body for updating the chat
|
||||||
|
* model override configuration endpoint.
|
||||||
|
*/
|
||||||
|
export interface UpdateChatModelOverrideRequest {
|
||||||
|
readonly model_config_id: string;
|
||||||
|
}
|
||||||
|
|
||||||
// From codersdk/chats.go
|
// From codersdk/chats.go
|
||||||
/**
|
/**
|
||||||
* UpdateChatPlanModeInstructionsRequest is the request body for
|
* UpdateChatPlanModeInstructionsRequest is the request body for
|
||||||
|
|||||||
@@ -12,31 +12,30 @@ import { useAuthenticated } from "#/hooks/useAuthenticated";
|
|||||||
import { RequirePermission } from "#/modules/permissions/RequirePermission";
|
import { RequirePermission } from "#/modules/permissions/RequirePermission";
|
||||||
import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView";
|
import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView";
|
||||||
|
|
||||||
const generalOverrideContext: TypesGen.ChatAgentModelOverrideContext =
|
const generalOverrideContext: TypesGen.ChatModelOverrideContext = "general";
|
||||||
"general";
|
const exploreOverrideContext: TypesGen.ChatModelOverrideContext = "explore";
|
||||||
const exploreOverrideContext: TypesGen.ChatAgentModelOverrideContext =
|
const titleGenerationOverrideContext: TypesGen.ChatModelOverrideContext =
|
||||||
"explore";
|
"title_generation";
|
||||||
|
|
||||||
const chatAgentModelOverrideKey = (
|
const chatModelOverrideKey = (context: TypesGen.ChatModelOverrideContext) =>
|
||||||
context: TypesGen.ChatAgentModelOverrideContext,
|
["chat-model-override", context] as const;
|
||||||
) => ["chat-agent-model-override", context] as const;
|
|
||||||
|
|
||||||
const chatAgentModelOverrideQuery = (
|
const chatModelOverrideQuery = (
|
||||||
context: TypesGen.ChatAgentModelOverrideContext,
|
context: TypesGen.ChatModelOverrideContext,
|
||||||
) => ({
|
) => ({
|
||||||
queryKey: chatAgentModelOverrideKey(context),
|
queryKey: chatModelOverrideKey(context),
|
||||||
queryFn: () => API.experimental.getChatAgentModelOverride(context),
|
queryFn: () => API.experimental.getChatModelOverride(context),
|
||||||
});
|
});
|
||||||
|
|
||||||
const updateChatAgentModelOverrideMutation = (
|
const updateChatModelOverrideMutation = (
|
||||||
queryClient: QueryClient,
|
queryClient: QueryClient,
|
||||||
context: TypesGen.ChatAgentModelOverrideContext,
|
context: TypesGen.ChatModelOverrideContext,
|
||||||
) => ({
|
) => ({
|
||||||
mutationFn: (req: TypesGen.UpdateChatAgentModelOverrideRequest) =>
|
mutationFn: (req: TypesGen.UpdateChatModelOverrideRequest) =>
|
||||||
API.experimental.updateChatAgentModelOverride(context, req),
|
API.experimental.updateChatModelOverride(context, req),
|
||||||
onSuccess: async () => {
|
onSuccess: async () => {
|
||||||
await queryClient.invalidateQueries({
|
await queryClient.invalidateQueries({
|
||||||
queryKey: chatAgentModelOverrideKey(context),
|
queryKey: chatModelOverrideKey(context),
|
||||||
exact: true,
|
exact: true,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
@@ -48,25 +47,36 @@ const AgentSettingsAgentsPage: FC = () => {
|
|||||||
const canEditDeploymentConfig = permissions.editDeploymentConfig;
|
const canEditDeploymentConfig = permissions.editDeploymentConfig;
|
||||||
|
|
||||||
const generalModelOverrideQuery = useQuery({
|
const generalModelOverrideQuery = useQuery({
|
||||||
...chatAgentModelOverrideQuery(generalOverrideContext),
|
...chatModelOverrideQuery(generalOverrideContext),
|
||||||
enabled: canEditDeploymentConfig,
|
enabled: canEditDeploymentConfig,
|
||||||
});
|
});
|
||||||
const exploreModelOverrideQuery = useQuery({
|
const exploreModelOverrideQuery = useQuery({
|
||||||
...chatAgentModelOverrideQuery(exploreOverrideContext),
|
...chatModelOverrideQuery(exploreOverrideContext),
|
||||||
|
enabled: canEditDeploymentConfig,
|
||||||
|
});
|
||||||
|
const titleGenerationModelQuery = useQuery({
|
||||||
|
...chatModelOverrideQuery(titleGenerationOverrideContext),
|
||||||
enabled: canEditDeploymentConfig,
|
enabled: canEditDeploymentConfig,
|
||||||
});
|
});
|
||||||
const modelConfigsQuery = useQuery(chatModelConfigs());
|
const modelConfigsQuery = useQuery(chatModelConfigs());
|
||||||
const saveGeneralModelOverrideMutation = useMutation(
|
const saveGeneralModelOverrideMutation = useMutation(
|
||||||
updateChatAgentModelOverrideMutation(queryClient, generalOverrideContext),
|
updateChatModelOverrideMutation(queryClient, generalOverrideContext),
|
||||||
|
);
|
||||||
|
const saveTitleGenerationModelMutation = useMutation(
|
||||||
|
updateChatModelOverrideMutation(
|
||||||
|
queryClient,
|
||||||
|
titleGenerationOverrideContext,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
const saveExploreModelOverrideMutation = useMutation(
|
const saveExploreModelOverrideMutation = useMutation(
|
||||||
updateChatAgentModelOverrideMutation(queryClient, exploreOverrideContext),
|
updateChatModelOverrideMutation(queryClient, exploreOverrideContext),
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<RequirePermission isFeatureVisible={canEditDeploymentConfig}>
|
<RequirePermission isFeatureVisible={canEditDeploymentConfig}>
|
||||||
<AgentSettingsAgentsPageView
|
<AgentSettingsAgentsPageView
|
||||||
generalModelOverrideData={generalModelOverrideQuery.data}
|
generalModelOverrideData={generalModelOverrideQuery.data}
|
||||||
|
titleGenerationModelOverrideData={titleGenerationModelQuery.data}
|
||||||
exploreModelOverrideData={exploreModelOverrideQuery.data}
|
exploreModelOverrideData={exploreModelOverrideQuery.data}
|
||||||
modelConfigsData={modelConfigsQuery.data}
|
modelConfigsData={modelConfigsQuery.data}
|
||||||
modelConfigsError={modelConfigsQuery.error}
|
modelConfigsError={modelConfigsQuery.error}
|
||||||
@@ -78,6 +88,13 @@ const AgentSettingsAgentsPage: FC = () => {
|
|||||||
isSaveGeneralModelOverrideError={
|
isSaveGeneralModelOverrideError={
|
||||||
saveGeneralModelOverrideMutation.isError
|
saveGeneralModelOverrideMutation.isError
|
||||||
}
|
}
|
||||||
|
onSaveTitleGenerationModel={saveTitleGenerationModelMutation.mutate}
|
||||||
|
isSavingTitleGenerationModel={
|
||||||
|
saveTitleGenerationModelMutation.isPending
|
||||||
|
}
|
||||||
|
isSaveTitleGenerationModelError={
|
||||||
|
saveTitleGenerationModelMutation.isError
|
||||||
|
}
|
||||||
onSaveExploreModelOverride={saveExploreModelOverrideMutation.mutate}
|
onSaveExploreModelOverride={saveExploreModelOverrideMutation.mutate}
|
||||||
isSavingExploreModelOverride={
|
isSavingExploreModelOverride={
|
||||||
saveExploreModelOverrideMutation.isPending
|
saveExploreModelOverrideMutation.isPending
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ const OVERRIDE_MALFORMED_WARNING =
|
|||||||
"The saved override is malformed and is being treated as unset. Click Save to clear it.";
|
"The saved override is malformed and is being treated as unset. Click Save to clear it.";
|
||||||
const UNAVAILABLE_SAVED_MODEL_WARNING =
|
const UNAVAILABLE_SAVED_MODEL_WARNING =
|
||||||
"The saved model is no longer enabled and will be ignored until you choose a new override.";
|
"The saved model is no longer enabled and will be ignored until you choose a new override.";
|
||||||
|
const TITLE_UNAVAILABLE_SAVED_MODEL_WARNING =
|
||||||
|
"The selected model is currently unavailable. Title generation will be skipped until you choose another model or clear this setting.";
|
||||||
|
|
||||||
const buildModelConfig = (
|
const buildModelConfig = (
|
||||||
overrides: Partial<TypesGen.ChatModelConfig>,
|
overrides: Partial<TypesGen.ChatModelConfig>,
|
||||||
@@ -28,15 +30,20 @@ const buildModelConfig = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
const buildOverrideData = (
|
const buildOverrideData = (
|
||||||
context: TypesGen.ChatAgentModelOverrideContext,
|
context: TypesGen.ChatModelOverrideContext,
|
||||||
overrides: Partial<TypesGen.ChatAgentModelOverrideResponse> = {},
|
overrides: Partial<TypesGen.ChatModelOverrideResponse> = {},
|
||||||
): TypesGen.ChatAgentModelOverrideResponse => ({
|
): TypesGen.ChatModelOverrideResponse => ({
|
||||||
context,
|
context,
|
||||||
model_config_id: "",
|
model_config_id: "",
|
||||||
is_malformed: false,
|
is_malformed: false,
|
||||||
...overrides,
|
...overrides,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const buildTitleGenerationModelOverrideData = (
|
||||||
|
overrides: Partial<TypesGen.ChatModelOverrideResponse> = {},
|
||||||
|
): TypesGen.ChatModelOverrideResponse =>
|
||||||
|
buildOverrideData("title_generation", overrides);
|
||||||
|
|
||||||
const generalModelConfig = buildModelConfig({
|
const generalModelConfig = buildModelConfig({
|
||||||
id: "model-general-gpt-4.1-mini",
|
id: "model-general-gpt-4.1-mini",
|
||||||
display_name: "GPT 4.1 Mini",
|
display_name: "GPT 4.1 Mini",
|
||||||
@@ -50,6 +57,13 @@ const claudeSonnetModelConfig = buildModelConfig({
|
|||||||
context_limit: 200_000,
|
context_limit: 200_000,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const titleModelConfig = buildModelConfig({
|
||||||
|
id: "model-title-gpt-4o-mini",
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
display_name: "GPT 4o Mini",
|
||||||
|
context_limit: 128_000,
|
||||||
|
});
|
||||||
|
|
||||||
const exploreFallbackModelConfig = buildModelConfig({
|
const exploreFallbackModelConfig = buildModelConfig({
|
||||||
id: "model-explore-blank-display",
|
id: "model-explore-blank-display",
|
||||||
provider: "anthropic",
|
provider: "anthropic",
|
||||||
@@ -65,6 +79,14 @@ const generalDisabledModelConfig = buildModelConfig({
|
|||||||
enabled: false,
|
enabled: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const titleDisabledModelConfig = buildModelConfig({
|
||||||
|
id: "model-title-disabled",
|
||||||
|
model: "gpt-4o-mini-legacy",
|
||||||
|
display_name: "GPT 4o Mini Legacy",
|
||||||
|
enabled: false,
|
||||||
|
context_limit: 128_000,
|
||||||
|
});
|
||||||
|
|
||||||
const exploreDisabledModelConfig = buildModelConfig({
|
const exploreDisabledModelConfig = buildModelConfig({
|
||||||
id: "model-explore-disabled",
|
id: "model-explore-disabled",
|
||||||
provider: "anthropic",
|
provider: "anthropic",
|
||||||
@@ -77,8 +99,10 @@ const exploreDisabledModelConfig = buildModelConfig({
|
|||||||
const allModelConfigs: TypesGen.ChatModelConfig[] = [
|
const allModelConfigs: TypesGen.ChatModelConfig[] = [
|
||||||
generalModelConfig,
|
generalModelConfig,
|
||||||
claudeSonnetModelConfig,
|
claudeSonnetModelConfig,
|
||||||
|
titleModelConfig,
|
||||||
exploreFallbackModelConfig,
|
exploreFallbackModelConfig,
|
||||||
generalDisabledModelConfig,
|
generalDisabledModelConfig,
|
||||||
|
titleDisabledModelConfig,
|
||||||
exploreDisabledModelConfig,
|
exploreDisabledModelConfig,
|
||||||
];
|
];
|
||||||
|
|
||||||
@@ -86,6 +110,7 @@ const makeArgs = (
|
|||||||
overrides: Partial<AgentSettingsAgentsPageViewProps> = {},
|
overrides: Partial<AgentSettingsAgentsPageViewProps> = {},
|
||||||
): AgentSettingsAgentsPageViewProps => ({
|
): AgentSettingsAgentsPageViewProps => ({
|
||||||
generalModelOverrideData: buildOverrideData("general"),
|
generalModelOverrideData: buildOverrideData("general"),
|
||||||
|
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData(),
|
||||||
exploreModelOverrideData: buildOverrideData("explore"),
|
exploreModelOverrideData: buildOverrideData("explore"),
|
||||||
modelConfigsData: allModelConfigs,
|
modelConfigsData: allModelConfigs,
|
||||||
modelConfigsError: undefined,
|
modelConfigsError: undefined,
|
||||||
@@ -93,6 +118,9 @@ const makeArgs = (
|
|||||||
onSaveGeneralModelOverride: fn(),
|
onSaveGeneralModelOverride: fn(),
|
||||||
isSavingGeneralModelOverride: false,
|
isSavingGeneralModelOverride: false,
|
||||||
isSaveGeneralModelOverrideError: false,
|
isSaveGeneralModelOverrideError: false,
|
||||||
|
onSaveTitleGenerationModel: fn(),
|
||||||
|
isSavingTitleGenerationModel: false,
|
||||||
|
isSaveTitleGenerationModelError: false,
|
||||||
onSaveExploreModelOverride: fn(),
|
onSaveExploreModelOverride: fn(),
|
||||||
isSavingExploreModelOverride: false,
|
isSavingExploreModelOverride: false,
|
||||||
isSaveExploreModelOverrideError: false,
|
isSaveExploreModelOverrideError: false,
|
||||||
@@ -146,13 +174,28 @@ export const AllOverridesUnset: Story = {
|
|||||||
const headings = await canvas.findAllByRole("heading", { level: 3 });
|
const headings = await canvas.findAllByRole("heading", { level: 3 });
|
||||||
expect(headings.map((heading) => heading.textContent?.trim())).toEqual([
|
expect(headings.map((heading) => heading.textContent?.trim())).toEqual([
|
||||||
"General model",
|
"General model",
|
||||||
|
"Title generation model",
|
||||||
"Explore subagent model",
|
"Explore subagent model",
|
||||||
]);
|
]);
|
||||||
|
await canvas.findByText(
|
||||||
|
"Choose a model for generated chat titles. Leave unset to use Coder's default title algorithm, which currently tries fast title models for configured providers first, for example Claude Haiku, GPT-4o mini, and Gemini Flash, then falls back to the chat's current model. When a model is selected here, Coder uses only that model for title generation. Recommended title models are fast and low cost.",
|
||||||
|
);
|
||||||
|
|
||||||
for (const headingName of ["General model", "Explore subagent model"]) {
|
const unsetSections = [
|
||||||
|
{ headingName: "General model", placeholder: "Use chat default" },
|
||||||
|
{
|
||||||
|
headingName: "Title generation model",
|
||||||
|
placeholder: "Use title default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headingName: "Explore subagent model",
|
||||||
|
placeholder: "Use chat default",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
for (const { headingName, placeholder } of unsetSections) {
|
||||||
const section = await getSection(canvasElement, headingName);
|
const section = await getSection(canvasElement, headingName);
|
||||||
expect(
|
expect(
|
||||||
within(section).getByRole("combobox", { name: "Use chat default" }),
|
within(section).getByRole("combobox", { name: placeholder }),
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
expect(
|
expect(
|
||||||
within(section).getByRole("button", { name: "Save" }),
|
within(section).getByRole("button", { name: "Save" }),
|
||||||
@@ -166,12 +209,19 @@ export const EachOverrideSetToEnabledModel: Story = {
|
|||||||
generalModelOverrideData: buildOverrideData("general", {
|
generalModelOverrideData: buildOverrideData("general", {
|
||||||
model_config_id: generalModelConfig.id,
|
model_config_id: generalModelConfig.id,
|
||||||
}),
|
}),
|
||||||
|
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({
|
||||||
|
model_config_id: titleModelConfig.id,
|
||||||
|
}),
|
||||||
exploreModelOverrideData: buildOverrideData("explore", {
|
exploreModelOverrideData: buildOverrideData("explore", {
|
||||||
model_config_id: exploreFallbackModelConfig.id,
|
model_config_id: exploreFallbackModelConfig.id,
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
play: async ({ canvasElement, args }) => {
|
play: async ({ canvasElement, args }) => {
|
||||||
const generalSection = await getSection(canvasElement, "General model");
|
const generalSection = await getSection(canvasElement, "General model");
|
||||||
|
const titleSection = await getSection(
|
||||||
|
canvasElement,
|
||||||
|
"Title generation model",
|
||||||
|
);
|
||||||
const exploreSection = await getSection(
|
const exploreSection = await getSection(
|
||||||
canvasElement,
|
canvasElement,
|
||||||
"Explore subagent model",
|
"Explore subagent model",
|
||||||
@@ -183,6 +233,12 @@ export const EachOverrideSetToEnabledModel: Story = {
|
|||||||
}),
|
}),
|
||||||
).toHaveTextContent("claude-sonnet-4-20250514");
|
).toHaveTextContent("claude-sonnet-4-20250514");
|
||||||
|
|
||||||
|
expect(
|
||||||
|
within(titleSection).getByRole("combobox", {
|
||||||
|
name: /gpt 4o mini/i,
|
||||||
|
}),
|
||||||
|
).toHaveTextContent("GPT 4o Mini");
|
||||||
|
|
||||||
await selectModelInSection(
|
await selectModelInSection(
|
||||||
generalSection,
|
generalSection,
|
||||||
canvasElement,
|
canvasElement,
|
||||||
@@ -203,6 +259,26 @@ export const EachOverrideSetToEnabledModel: Story = {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
await selectModelInSection(
|
||||||
|
titleSection,
|
||||||
|
canvasElement,
|
||||||
|
/gpt 4o mini/i,
|
||||||
|
"Claude Sonnet 4",
|
||||||
|
);
|
||||||
|
const titleSaveButton = within(titleSection).getByRole("button", {
|
||||||
|
name: "Save",
|
||||||
|
});
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(titleSaveButton).toBeEnabled();
|
||||||
|
});
|
||||||
|
await userEvent.click(titleSaveButton);
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(args.onSaveTitleGenerationModel).toHaveBeenCalledWith(
|
||||||
|
{ model_config_id: claudeSonnetModelConfig.id },
|
||||||
|
expect.anything(),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
const exploreClearButton = within(exploreSection).getByRole("button", {
|
const exploreClearButton = within(exploreSection).getByRole("button", {
|
||||||
name: "Clear",
|
name: "Clear",
|
||||||
});
|
});
|
||||||
@@ -228,18 +304,25 @@ export const MalformedOverridesRemainClearableAndSaveable: Story = {
|
|||||||
generalModelOverrideData: buildOverrideData("general", {
|
generalModelOverrideData: buildOverrideData("general", {
|
||||||
is_malformed: true,
|
is_malformed: true,
|
||||||
}),
|
}),
|
||||||
|
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({
|
||||||
|
is_malformed: true,
|
||||||
|
}),
|
||||||
exploreModelOverrideData: buildOverrideData("explore", {
|
exploreModelOverrideData: buildOverrideData("explore", {
|
||||||
is_malformed: true,
|
is_malformed: true,
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
play: async ({ canvasElement, args }) => {
|
play: async ({ canvasElement, args }) => {
|
||||||
const generalSection = await getSection(canvasElement, "General model");
|
const generalSection = await getSection(canvasElement, "General model");
|
||||||
|
const titleSection = await getSection(
|
||||||
|
canvasElement,
|
||||||
|
"Title generation model",
|
||||||
|
);
|
||||||
const exploreSection = await getSection(
|
const exploreSection = await getSection(
|
||||||
canvasElement,
|
canvasElement,
|
||||||
"Explore subagent model",
|
"Explore subagent model",
|
||||||
);
|
);
|
||||||
|
|
||||||
for (const section of [generalSection, exploreSection]) {
|
for (const section of [generalSection, titleSection, exploreSection]) {
|
||||||
await within(section).findByText(OVERRIDE_MALFORMED_WARNING);
|
await within(section).findByText(OVERRIDE_MALFORMED_WARNING);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,6 +340,20 @@ export const MalformedOverridesRemainClearableAndSaveable: Story = {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const titleSaveButton = within(titleSection).getByRole("button", {
|
||||||
|
name: "Save",
|
||||||
|
});
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(titleSaveButton).toBeEnabled();
|
||||||
|
});
|
||||||
|
await userEvent.click(titleSaveButton);
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(args.onSaveTitleGenerationModel).toHaveBeenCalledWith(
|
||||||
|
{ model_config_id: "" },
|
||||||
|
expect.anything(),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
const exploreSaveButton = within(exploreSection).getByRole("button", {
|
const exploreSaveButton = within(exploreSection).getByRole("button", {
|
||||||
name: "Save",
|
name: "Save",
|
||||||
});
|
});
|
||||||
@@ -278,12 +375,19 @@ export const UnavailableSavedModels: Story = {
|
|||||||
generalModelOverrideData: buildOverrideData("general", {
|
generalModelOverrideData: buildOverrideData("general", {
|
||||||
model_config_id: generalDisabledModelConfig.id,
|
model_config_id: generalDisabledModelConfig.id,
|
||||||
}),
|
}),
|
||||||
|
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({
|
||||||
|
model_config_id: titleDisabledModelConfig.id,
|
||||||
|
}),
|
||||||
exploreModelOverrideData: buildOverrideData("explore", {
|
exploreModelOverrideData: buildOverrideData("explore", {
|
||||||
model_config_id: exploreDisabledModelConfig.id,
|
model_config_id: exploreDisabledModelConfig.id,
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
play: async ({ canvasElement }) => {
|
play: async ({ canvasElement }) => {
|
||||||
const generalSection = await getSection(canvasElement, "General model");
|
const generalSection = await getSection(canvasElement, "General model");
|
||||||
|
const titleSection = await getSection(
|
||||||
|
canvasElement,
|
||||||
|
"Title generation model",
|
||||||
|
);
|
||||||
const exploreSection = await getSection(
|
const exploreSection = await getSection(
|
||||||
canvasElement,
|
canvasElement,
|
||||||
"Explore subagent model",
|
"Explore subagent model",
|
||||||
@@ -295,5 +399,13 @@ export const UnavailableSavedModels: Story = {
|
|||||||
within(section).getByRole("combobox", { name: "Unavailable model" }),
|
within(section).getByRole("combobox", { name: "Unavailable model" }),
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
}
|
}
|
||||||
|
await within(titleSection).findByText(
|
||||||
|
TITLE_UNAVAILABLE_SAVED_MODEL_WARNING,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
within(titleSection).getByRole("combobox", {
|
||||||
|
name: "Unavailable model",
|
||||||
|
}),
|
||||||
|
).toBeInTheDocument();
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -7,19 +7,23 @@ import {
|
|||||||
} from "./components/SubagentModelOverrideSettings";
|
} from "./components/SubagentModelOverrideSettings";
|
||||||
|
|
||||||
type SaveModelOverride = (
|
type SaveModelOverride = (
|
||||||
req: TypesGen.UpdateChatAgentModelOverrideRequest,
|
req: { readonly model_config_id: string },
|
||||||
options?: MutationCallbacks,
|
options?: MutationCallbacks,
|
||||||
) => void;
|
) => void;
|
||||||
|
|
||||||
export interface AgentSettingsAgentsPageViewProps {
|
export interface AgentSettingsAgentsPageViewProps {
|
||||||
generalModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
|
generalModelOverrideData?: TypesGen.ChatModelOverrideResponse;
|
||||||
exploreModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
|
titleGenerationModelOverrideData?: TypesGen.ChatModelOverrideResponse;
|
||||||
|
exploreModelOverrideData?: TypesGen.ChatModelOverrideResponse;
|
||||||
modelConfigsData: TypesGen.ChatModelConfig[] | undefined;
|
modelConfigsData: TypesGen.ChatModelConfig[] | undefined;
|
||||||
modelConfigsError: unknown;
|
modelConfigsError: unknown;
|
||||||
isLoadingModelConfigs: boolean;
|
isLoadingModelConfigs: boolean;
|
||||||
onSaveGeneralModelOverride?: SaveModelOverride;
|
onSaveGeneralModelOverride?: SaveModelOverride;
|
||||||
isSavingGeneralModelOverride?: boolean;
|
isSavingGeneralModelOverride?: boolean;
|
||||||
isSaveGeneralModelOverrideError?: boolean;
|
isSaveGeneralModelOverrideError?: boolean;
|
||||||
|
onSaveTitleGenerationModel: SaveModelOverride;
|
||||||
|
isSavingTitleGenerationModel: boolean;
|
||||||
|
isSaveTitleGenerationModelError: boolean;
|
||||||
onSaveExploreModelOverride: SaveModelOverride;
|
onSaveExploreModelOverride: SaveModelOverride;
|
||||||
isSavingExploreModelOverride: boolean;
|
isSavingExploreModelOverride: boolean;
|
||||||
isSaveExploreModelOverrideError: boolean;
|
isSaveExploreModelOverrideError: boolean;
|
||||||
@@ -29,6 +33,7 @@ export const AgentSettingsAgentsPageView: FC<
|
|||||||
AgentSettingsAgentsPageViewProps
|
AgentSettingsAgentsPageViewProps
|
||||||
> = ({
|
> = ({
|
||||||
generalModelOverrideData,
|
generalModelOverrideData,
|
||||||
|
titleGenerationModelOverrideData,
|
||||||
exploreModelOverrideData,
|
exploreModelOverrideData,
|
||||||
modelConfigsData,
|
modelConfigsData,
|
||||||
modelConfigsError,
|
modelConfigsError,
|
||||||
@@ -36,6 +41,9 @@ export const AgentSettingsAgentsPageView: FC<
|
|||||||
onSaveGeneralModelOverride,
|
onSaveGeneralModelOverride,
|
||||||
isSavingGeneralModelOverride = false,
|
isSavingGeneralModelOverride = false,
|
||||||
isSaveGeneralModelOverrideError = false,
|
isSaveGeneralModelOverrideError = false,
|
||||||
|
onSaveTitleGenerationModel,
|
||||||
|
isSavingTitleGenerationModel,
|
||||||
|
isSaveTitleGenerationModelError,
|
||||||
onSaveExploreModelOverride,
|
onSaveExploreModelOverride,
|
||||||
isSavingExploreModelOverride,
|
isSavingExploreModelOverride,
|
||||||
isSaveExploreModelOverrideError,
|
isSaveExploreModelOverrideError,
|
||||||
@@ -77,6 +85,31 @@ export const AgentSettingsAgentsPageView: FC<
|
|||||||
/>
|
/>
|
||||||
</section>
|
</section>
|
||||||
)}
|
)}
|
||||||
|
<section
|
||||||
|
aria-label="Title generation model"
|
||||||
|
className="flex flex-col gap-3"
|
||||||
|
>
|
||||||
|
<SectionHeader
|
||||||
|
label="Title generation model"
|
||||||
|
description="Choose a model for generated chat titles. Leave unset to use Coder's default title algorithm, which currently tries fast title models for configured providers first, for example Claude Haiku, GPT-4o mini, and Gemini Flash, then falls back to the chat's current model. When a model is selected here, Coder uses only that model for title generation. Recommended title models are fast and low cost."
|
||||||
|
level="section"
|
||||||
|
/>
|
||||||
|
<SubagentModelOverrideSettings
|
||||||
|
title="Title generation model"
|
||||||
|
description="Choose a model for generated chat titles."
|
||||||
|
modelOverrideData={titleGenerationModelOverrideData}
|
||||||
|
enabledModelConfigs={enabledModelConfigs}
|
||||||
|
modelConfigsError={modelConfigsError}
|
||||||
|
isLoading={isLoadingModelConfigs}
|
||||||
|
onSaveModelOverride={onSaveTitleGenerationModel}
|
||||||
|
isSaving={isSavingTitleGenerationModel}
|
||||||
|
isSaveError={isSaveTitleGenerationModelError}
|
||||||
|
saveErrorMessage="Failed to save title generation model."
|
||||||
|
unsetPlaceholder="Use title default"
|
||||||
|
unavailableModelWarning="The selected model is currently unavailable. Title generation will be skipped until you choose another model or clear this setting."
|
||||||
|
showHeader={false}
|
||||||
|
/>
|
||||||
|
</section>
|
||||||
<section
|
<section
|
||||||
aria-label="Explore subagent model"
|
aria-label="Explore subagent model"
|
||||||
className="flex flex-col gap-3"
|
className="flex flex-col gap-3"
|
||||||
|
|||||||
@@ -161,9 +161,17 @@ const AgentsRouteElement = () => (
|
|||||||
model_config_id: "",
|
model_config_id: "",
|
||||||
is_malformed: false,
|
is_malformed: false,
|
||||||
}}
|
}}
|
||||||
|
titleGenerationModelOverrideData={{
|
||||||
|
context: "title_generation",
|
||||||
|
model_config_id: "",
|
||||||
|
is_malformed: false,
|
||||||
|
}}
|
||||||
modelConfigsData={[]}
|
modelConfigsData={[]}
|
||||||
modelConfigsError={undefined}
|
modelConfigsError={undefined}
|
||||||
isLoadingModelConfigs={false}
|
isLoadingModelConfigs={false}
|
||||||
|
onSaveTitleGenerationModel={fn()}
|
||||||
|
isSavingTitleGenerationModel={false}
|
||||||
|
isSaveTitleGenerationModelError={false}
|
||||||
onSaveExploreModelOverride={fn()}
|
onSaveExploreModelOverride={fn()}
|
||||||
isSavingExploreModelOverride={false}
|
isSavingExploreModelOverride={false}
|
||||||
isSaveExploreModelOverrideError={false}
|
isSaveExploreModelOverrideError={false}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ interface UpdateModelOverrideRequest {
|
|||||||
|
|
||||||
interface SubagentModelOverrideSettingsProps {
|
interface SubagentModelOverrideSettingsProps {
|
||||||
title: string;
|
title: string;
|
||||||
description: ReactNode;
|
description?: ReactNode;
|
||||||
modelOverrideData: ModelOverrideData | undefined;
|
modelOverrideData: ModelOverrideData | undefined;
|
||||||
enabledModelConfigs: readonly TypesGen.ChatModelConfig[];
|
enabledModelConfigs: readonly TypesGen.ChatModelConfig[];
|
||||||
modelConfigsError: unknown;
|
modelConfigsError: unknown;
|
||||||
@@ -34,6 +34,8 @@ interface SubagentModelOverrideSettingsProps {
|
|||||||
isSaving: boolean;
|
isSaving: boolean;
|
||||||
isSaveError: boolean;
|
isSaveError: boolean;
|
||||||
saveErrorMessage: string;
|
saveErrorMessage: string;
|
||||||
|
unsetPlaceholder?: string;
|
||||||
|
unavailableModelWarning?: string;
|
||||||
showHeader?: boolean;
|
showHeader?: boolean;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
@@ -61,6 +63,8 @@ export const SubagentModelOverrideSettings: FC<
|
|||||||
isSaving,
|
isSaving,
|
||||||
isSaveError,
|
isSaveError,
|
||||||
saveErrorMessage,
|
saveErrorMessage,
|
||||||
|
unsetPlaceholder = "Use chat default",
|
||||||
|
unavailableModelWarning = "The saved model is no longer enabled and will be ignored until you choose a new override.",
|
||||||
showHeader = true,
|
showHeader = true,
|
||||||
disabled = false,
|
disabled = false,
|
||||||
}) => {
|
}) => {
|
||||||
@@ -104,9 +108,11 @@ export const SubagentModelOverrideSettings: FC<
|
|||||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||||
{title}
|
{title}
|
||||||
</h3>
|
</h3>
|
||||||
|
{description && (
|
||||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||||
{description}
|
{description}
|
||||||
</p>
|
</p>
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
<ModelSelector
|
<ModelSelector
|
||||||
@@ -115,7 +121,7 @@ export const SubagentModelOverrideSettings: FC<
|
|||||||
onValueChange={(value) => form.setFieldValue("model_config_id", value)}
|
onValueChange={(value) => form.setFieldValue("model_config_id", value)}
|
||||||
disabled={isModelOverrideDisabled}
|
disabled={isModelOverrideDisabled}
|
||||||
placeholder={
|
placeholder={
|
||||||
isUnavailableSavedModel ? "Unavailable model" : "Use chat default"
|
isUnavailableSavedModel ? "Unavailable model" : unsetPlaceholder
|
||||||
}
|
}
|
||||||
emptyMessage={
|
emptyMessage={
|
||||||
isLoading ? "Loading models..." : "No enabled models found."
|
isLoading ? "Loading models..." : "No enabled models found."
|
||||||
@@ -125,10 +131,7 @@ export const SubagentModelOverrideSettings: FC<
|
|||||||
/>
|
/>
|
||||||
{isUnavailableSavedModel && (
|
{isUnavailableSavedModel && (
|
||||||
<Alert severity="warning">
|
<Alert severity="warning">
|
||||||
<AlertDescription>
|
<AlertDescription>{unavailableModelWarning}</AlertDescription>
|
||||||
The saved model is no longer enabled and will be ignored until you
|
|
||||||
choose a new override.
|
|
||||||
</AlertDescription>
|
|
||||||
</Alert>
|
</Alert>
|
||||||
)}
|
)}
|
||||||
{isMalformedOverride && (
|
{isMalformedOverride && (
|
||||||
|
|||||||
Reference in New Issue
Block a user