diff --git a/coderd/coderd.go b/coderd/coderd.go index f2410cfff2..1ff2a5ed48 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1192,8 +1192,8 @@ func New(options *Options) *API { r.Put("/system-prompt", api.putChatSystemPrompt) r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions) r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions) - r.Get("/agent-model-override/{context}", api.getChatAgentModelOverride) - r.Put("/agent-model-override/{context}", api.putChatAgentModelOverride) + r.Get("/model-override/{context}", api.getChatModelOverride) + r.Put("/model-override/{context}", api.putChatModelOverride) r.Get("/desktop-enabled", api.getChatDesktopEnabled) r.Put("/desktop-enabled", api.putChatDesktopEnabled) r.Get("/debug-logging", api.getChatDebugLogging) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a724c5959d..0ab0b52123 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2967,6 +2967,13 @@ func (q *querier) GetChatTemplateAllowlist(ctx context.Context) (string, error) return q.db.GetChatTemplateAllowlist(ctx) } +func (q *querier) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return "", err + } + return q.db.GetChatTitleGenerationModelOverride(ctx) +} + func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return database.ChatUsageLimitConfig{}, err @@ -7517,6 +7524,13 @@ func (q *querier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllow return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist) } +func (q *querier) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.UpsertChatTitleGenerationModelOverride(ctx, value) +} + func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return database.ChatUsageLimitConfig{}, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index db39439930..2d6b189c8c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -918,6 +918,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) })) + s.Run("GetChatTitleGenerationModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) + })) s.Run("GetChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetChatPlanModeInstructions(gomock.Any()).Return("", nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) @@ -1237,6 +1241,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) + s.Run("UpsertChatTitleGenerationModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertChatTitleGenerationModelOverride(gomock.Any(), "").Return(nil).AnyTimes() + check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) s.Run("UpsertChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().UpsertChatPlanModeInstructions(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 35b5f13815..f9b7c9651b 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1456,6 +1456,14 @@ func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string return r0, r1 } +func (m queryMetricsStore) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetChatTitleGenerationModelOverride(ctx) + m.queryLatencies.WithLabelValues("GetChatTitleGenerationModelOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatTitleGenerationModelOverride").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) { start := time.Now() r0, r1 := m.s.GetChatUsageLimitConfig(ctx) @@ -5408,6 +5416,14 @@ func (m queryMetricsStore) UpsertChatTemplateAllowlist(ctx context.Context, temp return r0 } +func (m queryMetricsStore) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error { + start := time.Now() + r0 := m.s.UpsertChatTitleGenerationModelOverride(ctx, value) + m.queryLatencies.WithLabelValues("UpsertChatTitleGenerationModelOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatTitleGenerationModelOverride").Inc() + return r0 +} + func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) { start := time.Now() r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 45c0d4a97c..625c5a53cf 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2687,6 +2687,21 @@ func (mr *MockStoreMockRecorder) GetChatTemplateAllowlist(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx) } +// GetChatTitleGenerationModelOverride mocks base method. +func (m *MockStore) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatTitleGenerationModelOverride", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatTitleGenerationModelOverride indicates an expected call of GetChatTitleGenerationModelOverride. +func (mr *MockStoreMockRecorder) GetChatTitleGenerationModelOverride(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTitleGenerationModelOverride", reflect.TypeOf((*MockStore)(nil).GetChatTitleGenerationModelOverride), ctx) +} + // GetChatUsageLimitConfig mocks base method. func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) { m.ctrl.T.Helper() @@ -10152,6 +10167,20 @@ func (mr *MockStoreMockRecorder) UpsertChatTemplateAllowlist(ctx, templateAllowl return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist) } +// UpsertChatTitleGenerationModelOverride mocks base method. +func (m *MockStore) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChatTitleGenerationModelOverride", ctx, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChatTitleGenerationModelOverride indicates an expected call of UpsertChatTitleGenerationModelOverride. +func (mr *MockStoreMockRecorder) UpsertChatTitleGenerationModelOverride(ctx, value any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTitleGenerationModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatTitleGenerationModelOverride), ctx, value) +} + // UpsertChatUsageLimitConfig mocks base method. func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 85749324c2..6a8cedd3ab 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -362,6 +362,7 @@ type sqlcQuerier interface { // GetChatTemplateAllowlist returns the JSON-encoded template allowlist. // Returns an empty string when no allowlist has been configured (all templates allowed). GetChatTemplateAllowlist(ctx context.Context) (string, error) + GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error) @@ -1206,6 +1207,7 @@ type sqlcQuerier interface { UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error UpsertChatSystemPrompt(ctx context.Context, value string) error UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error + UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error) UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index ac061d90b5..ee331ee2f6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -20731,6 +20731,18 @@ func (q *sqlQuerier) GetChatTemplateAllowlist(ctx context.Context) (string, erro return template_allowlist, err } +const getChatTitleGenerationModelOverride = `-- name: GetChatTitleGenerationModelOverride :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_title_generation_model_override'), '') :: text AS model_config_id +` + +func (q *sqlQuerier) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getChatTitleGenerationModelOverride) + var model_config_id string + err := row.Scan(&model_config_id) + return model_config_id, err +} + const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one SELECT COALESCE( @@ -21085,6 +21097,16 @@ func (q *sqlQuerier) UpsertChatTemplateAllowlist(ctx context.Context, templateAl return err } +const upsertChatTitleGenerationModelOverride = `-- name: UpsertChatTitleGenerationModelOverride :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_title_generation_model_override', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_title_generation_model_override' +` + +func (q *sqlQuerier) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error { + _, err := q.db.ExecContext(ctx, upsertChatTitleGenerationModelOverride, value) + return err +} + const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec INSERT INTO site_configs (key, value) VALUES ('agents_workspace_ttl', $1::text) diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index 2001b910e3..5c6e591023 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -183,6 +183,14 @@ SELECT INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1) ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override'; +-- name: GetChatTitleGenerationModelOverride :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_title_generation_model_override'), '') :: text AS model_config_id; + +-- name: UpsertChatTitleGenerationModelOverride :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_title_generation_model_override', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_title_generation_model_override'; + -- name: GetChatDesktopEnabled :one SELECT COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop; diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index cca6917e7d..9892f36ad4 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -532,62 +532,72 @@ func (api *API) getChatModelOverrideConfig( return id, false, nil } -func parseChatAgentModelOverrideContext(raw string) (codersdk.ChatAgentModelOverrideContext, error) { - overrideContext := codersdk.ChatAgentModelOverrideContext(raw) +func parseChatModelOverrideContext(raw string) (codersdk.ChatModelOverrideContext, error) { + overrideContext := codersdk.ChatModelOverrideContext(raw) if overrideContext.Valid() { return overrideContext, nil } - return "", xerrors.Errorf("unknown chat agent model override context %q", raw) + return "", xerrors.Errorf("unknown chat model override context %q", raw) } -type chatAgentModelOverrideSiteConfig struct { +type chatModelOverrideSiteConfig struct { + label string getter func(context.Context) (string, error) upsert func(context.Context, string) error } -func (api *API) chatAgentModelOverrideSiteConfig( - overrideContext codersdk.ChatAgentModelOverrideContext, -) (chatAgentModelOverrideSiteConfig, error) { +func (api *API) chatModelOverrideSiteConfig( + overrideContext codersdk.ChatModelOverrideContext, +) (chatModelOverrideSiteConfig, error) { switch overrideContext { - case codersdk.ChatAgentModelOverrideContextGeneral: - return chatAgentModelOverrideSiteConfig{ + case codersdk.ChatModelOverrideContextGeneral: + return chatModelOverrideSiteConfig{ + label: "general", getter: api.Database.GetChatGeneralModelOverride, upsert: api.Database.UpsertChatGeneralModelOverride, }, nil - case codersdk.ChatAgentModelOverrideContextExplore: - return chatAgentModelOverrideSiteConfig{ + case codersdk.ChatModelOverrideContextExplore: + return chatModelOverrideSiteConfig{ + label: "explore", getter: api.Database.GetChatExploreModelOverride, upsert: api.Database.UpsertChatExploreModelOverride, }, nil + case codersdk.ChatModelOverrideContextTitleGeneration: + return chatModelOverrideSiteConfig{ + label: "title generation", + getter: api.Database.GetChatTitleGenerationModelOverride, + upsert: api.Database.UpsertChatTitleGenerationModelOverride, + }, nil default: - return chatAgentModelOverrideSiteConfig{}, xerrors.Errorf( - "unknown chat agent model override context %q", + return chatModelOverrideSiteConfig{}, xerrors.Errorf( + "unknown chat model override context %q", overrideContext, ) } } -func (api *API) getChatAgentModelOverrideConfig( +func (api *API) readChatModelOverrideConfig( ctx context.Context, - overrideContext codersdk.ChatAgentModelOverrideContext, -) (*uuid.UUID, bool, error) { - siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext) + overrideContext codersdk.ChatModelOverrideContext, +) (*uuid.UUID, bool, string, error) { + siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext) if err != nil { - return nil, false, err + return nil, false, "", err } - return api.getChatModelOverrideConfig(ctx, string(overrideContext), siteConfig.getter) + id, isMalformed, err := api.getChatModelOverrideConfig(ctx, siteConfig.label, siteConfig.getter) + return id, isMalformed, siteConfig.label, err } -func (api *API) upsertChatAgentModelOverrideConfig( +func (api *API) upsertChatModelOverrideConfig( ctx context.Context, - overrideContext codersdk.ChatAgentModelOverrideContext, + overrideContext codersdk.ChatModelOverrideContext, modelConfigID *uuid.UUID, -) error { - siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext) +) (string, error) { + siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext) if err != nil { - return err + return "", err } - return siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID)) + return siteConfig.label, siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID)) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -3941,27 +3951,27 @@ func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Requ rw.WriteHeader(http.StatusNoContent) } -func readChatAgentModelOverrideContext( +func readChatModelOverrideContext( rw http.ResponseWriter, r *http.Request, -) (codersdk.ChatAgentModelOverrideContext, bool) { +) (codersdk.ChatModelOverrideContext, bool) { ctx := r.Context() rawContext := chi.URLParam(r, "context") - overrideContext, err := parseChatAgentModelOverrideContext(rawContext) + overrideContext, err := parseChatModelOverrideContext(rawContext) if err == nil { return overrideContext, true } validContextValues := make( []string, 0, - len(codersdk.AllChatAgentModelOverrideContexts()), + len(codersdk.AllChatModelOverrideContexts()), ) - for _, overrideContext := range codersdk.AllChatAgentModelOverrideContexts() { + for _, overrideContext := range codersdk.AllChatModelOverrideContexts() { validContextValues = append(validContextValues, string(overrideContext)) } validContexts := strings.Join(validContextValues, ", ") httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid chat agent model override context.", + Message: "Invalid chat model override context.", Detail: fmt.Sprintf( "Expected one of %s. Got %q.", validContexts, @@ -3974,27 +3984,30 @@ func readChatAgentModelOverrideContext( // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. -func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) { +func (api *API) getChatModelOverride(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) { httpapi.ResourceNotFound(rw) return } - overrideContext, ok := readChatAgentModelOverrideContext(rw, r) + overrideContext, ok := readChatModelOverrideContext(rw, r) if !ok { return } - modelConfigID, isMalformed, err := api.getChatAgentModelOverrideConfig(ctx, overrideContext) + modelConfigID, isMalformed, label, err := api.readChatModelOverrideConfig(ctx, overrideContext) if err != nil { + if label == "" { + label = string(overrideContext) + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: fmt.Sprintf("Internal error fetching %s model override.", overrideContext), + Message: fmt.Sprintf("Internal error fetching %s model override.", label), Detail: err.Error(), }) return } - resp := codersdk.ChatAgentModelOverrideResponse{ + resp := codersdk.ChatModelOverrideResponse{ Context: overrideContext, ModelConfigID: formatChatModelOverride(modelConfigID), IsMalformed: isMalformed, @@ -4004,18 +4017,18 @@ func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Reques } // EXPERIMENTAL: this endpoint is experimental and is subject to change. -func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) { +func (api *API) putChatModelOverride(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { httpapi.Forbidden(rw) return } - overrideContext, ok := readChatAgentModelOverrideContext(rw, r) + overrideContext, ok := readChatModelOverrideContext(rw, r) if !ok { return } - var req codersdk.UpdateChatAgentModelOverrideRequest + var req codersdk.UpdateChatModelOverrideRequest if !httpapi.Read(ctx, rw, r, &req) { return } @@ -4035,9 +4048,13 @@ func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Reques return } - if err := api.upsertChatAgentModelOverrideConfig(ctx, overrideContext, modelConfigID); err != nil { + label, err := api.upsertChatModelOverrideConfig(ctx, overrideContext, modelConfigID) + if err != nil { + if label == "" { + label = string(overrideContext) + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: fmt.Sprintf("Internal error updating %s model override.", overrideContext), + Message: fmt.Sprintf("Internal error updating %s model override.", label), Detail: err.Error(), }) return diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 8d1b370d52..c93bf71cef 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -10061,28 +10061,28 @@ func TestChatModelOverrides(t *testing.T) { t.Parallel() type overrideResponse struct { - context codersdk.ChatAgentModelOverrideContext + context codersdk.ChatModelOverrideContext modelConfigID string isMalformed bool } type settingTest struct { name string - context codersdk.ChatAgentModelOverrideContext + context codersdk.ChatModelOverrideContext dbGet func(context.Context, database.Store) (string, error) dbUpsert func(context.Context, database.Store, string) error } - settingPath := func(overrideContext codersdk.ChatAgentModelOverrideContext) string { - return "/api/experimental/chats/config/agent-model-override/" + string(overrideContext) + settingPath := func(overrideContext codersdk.ChatModelOverrideContext) string { + return "/api/experimental/chats/config/model-override/" + string(overrideContext) } getOverride := func( ctx context.Context, client *codersdk.ExperimentalClient, - overrideContext codersdk.ChatAgentModelOverrideContext, + overrideContext codersdk.ChatModelOverrideContext, ) (overrideResponse, error) { - resp, err := client.GetChatAgentModelOverride(ctx, overrideContext) + resp, err := client.GetChatModelOverride(ctx, overrideContext) if err != nil { return overrideResponse{}, err } @@ -10096,20 +10096,20 @@ func TestChatModelOverrides(t *testing.T) { putOverride := func( ctx context.Context, client *codersdk.ExperimentalClient, - overrideContext codersdk.ChatAgentModelOverrideContext, + overrideContext codersdk.ChatModelOverrideContext, modelConfigID string, ) error { - return client.UpdateChatAgentModelOverride( + return client.UpdateChatModelOverride( ctx, overrideContext, - codersdk.UpdateChatAgentModelOverrideRequest{ModelConfigID: modelConfigID}, + codersdk.UpdateChatModelOverrideRequest{ModelConfigID: modelConfigID}, ) } settings := []settingTest{ { name: "General", - context: codersdk.ChatAgentModelOverrideContextGeneral, + context: codersdk.ChatModelOverrideContextGeneral, dbGet: func(ctx context.Context, db database.Store) (string, error) { return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx)) }, @@ -10119,7 +10119,7 @@ func TestChatModelOverrides(t *testing.T) { }, { name: "Explore", - context: codersdk.ChatAgentModelOverrideContextExplore, + context: codersdk.ChatModelOverrideContextExplore, dbGet: func(ctx context.Context, db database.Store) (string, error) { return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx)) }, @@ -10127,6 +10127,16 @@ func TestChatModelOverrides(t *testing.T) { return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value) }, }, + { + name: "TitleGeneration", + context: codersdk.ChatModelOverrideContextTitleGeneration, + dbGet: func(ctx context.Context, db database.Store) (string, error) { + return db.GetChatTitleGenerationModelOverride(dbauthz.AsSystemRestricted(ctx)) + }, + dbUpsert: func(ctx context.Context, db database.Store, value string) error { + return db.UpsertChatTitleGenerationModelOverride(dbauthz.AsSystemRestricted(ctx), value) + }, + }, } for _, setting := range settings { @@ -10265,23 +10275,23 @@ func TestChatModelOverrides(t *testing.T) { adminClient := newChatClient(t) coderdtest.CreateFirstUser(t, adminClient.Client) - unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context") + unknownContext := codersdk.ChatModelOverrideContext("not-a-context") _, err := getOverride(ctx, adminClient, unknownContext) sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message) + require.Equal(t, "Invalid chat model override context.", sdkErr.Message) require.Equal( t, - `Expected one of general, explore. Got "not-a-context".`, + `Expected one of general, explore, title_generation. Got "not-a-context".`, sdkErr.Detail, ) err = putOverride(ctx, adminClient, unknownContext, "") sdkErr = requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message) + require.Equal(t, "Invalid chat model override context.", sdkErr.Message) require.Equal( t, - `Expected one of general, explore. Got "not-a-context".`, + `Expected one of general, explore, title_generation. Got "not-a-context".`, sdkErr.Detail, ) }) @@ -10293,7 +10303,7 @@ func TestChatModelOverrides(t *testing.T) { firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) - unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context") + unknownContext := codersdk.ChatModelOverrideContext("not-a-context") _, err := getOverride(ctx, memberClient, unknownContext) requireSDKError(t, err, http.StatusNotFound) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index ad8ff0c896..97127e13b1 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -3110,6 +3110,26 @@ func (p *Server) resolveManualTitleModel( chat database.Chat, keys chatprovider.ProviderAPIKeys, ) (fantasy.LanguageModel, database.ChatModelConfig, error) { + overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( + ctx, + chat, + keys, + ) + if overrideErr != nil { + if overrideSet { + return nil, database.ChatModelConfig{}, xerrors.Errorf( + "resolve manual title generation model override: %w", + overrideErr, + ) + } + p.logger.Debug(ctx, "failed to resolve title generation model override for manual title", + slog.F("chat_id", chat.ID), + slog.Error(overrideErr), + ) + } else if overrideSet { + return overrideModel, overrideConfig, nil + } + configs, err := store.GetEnabledChatModelConfigs(ctx) if err != nil { p.logger.Debug(ctx, "failed to list manual title model configs", diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index c22a4785a5..896e3e8363 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -636,6 +636,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { LimitVal: manualTitleMessageWindowLimit, }, ).Return(nil, nil) + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil) db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil) gomock.InOrder( @@ -799,6 +800,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t LimitVal: manualTitleMessageWindowLimit, }, ).Return(nil, nil) + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil) db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil) gomock.InOrder( diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index 449637c03e..683be44dbe 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -106,9 +106,10 @@ type generatedTitle struct { // maybeGenerateChatTitle generates an AI title for the chat when // appropriate (first user message, no assistant reply yet, and the // current title is either empty or still the fallback truncation). -// It tries cheap, fast models first and falls back to the user's -// chat model. It is a best-effort operation that logs and swallows -// errors. +// It uses the configured title generation model override when set. +// Otherwise, it tries cheap, fast models first and falls back to the +// user's chat model. It is a best-effort operation that logs and +// swallows errors. func (p *Server) maybeGenerateChatTitle( ctx context.Context, chat database.Chat, @@ -130,28 +131,58 @@ func (p *Server) maybeGenerateChatTitle( titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - // Build candidate list: preferred lightweight models first, - // then the user's chat model as last resort. - candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1) - for _, c := range preferredTitleModels { - m, err := chatprovider.ModelFromConfig( - c.provider, c.model, keys, chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) - if err == nil { - candidates = append(candidates, shortTextCandidate{ - provider: c.provider, - model: c.model, - lm: m, - }) + overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( + titleCtx, + chat, + keys, + ) + if overrideErr != nil { + if overrideSet { + logger.Warn(ctx, "title generation model override unavailable, skipping title generation", + slog.F("chat_id", chat.ID), + slog.F("override_context", titleGenerationOverrideContext), + slog.Error(overrideErr), + ) + return } + logger.Debug(ctx, "failed to resolve title generation model override", + slog.F("chat_id", chat.ID), + slog.F("override_context", titleGenerationOverrideContext), + slog.Error(overrideErr), + ) + } + + var candidates []shortTextCandidate + if overrideSet { + candidates = []shortTextCandidate{{ + provider: overrideConfig.Provider, + model: overrideConfig.Model, + lm: overrideModel, + }} + } else { + // Build candidate list: preferred lightweight models first, + // then the user's chat model as last resort. + candidates = make([]shortTextCandidate, 0, len(preferredTitleModels)+1) + for _, c := range preferredTitleModels { + m, err := chatprovider.ModelFromConfig( + c.provider, c.model, keys, chatprovider.UserAgent(), + chatprovider.CoderHeaders(chat), + nil, + ) + if err == nil { + candidates = append(candidates, shortTextCandidate{ + provider: c.provider, + model: c.model, + lm: m, + }) + } + } + candidates = append(candidates, shortTextCandidate{ + provider: fallbackProvider, + model: fallbackModelName, + lm: fallbackModel, + }) } - candidates = append(candidates, shortTextCandidate{ - provider: fallbackProvider, - model: fallbackModelName, - lm: fallbackModel, - }) var historyTipMessageID int64 if len(messages) > 0 { @@ -197,10 +228,20 @@ func (p *Server) maybeGenerateChatTitle( finishDebugRun(err) if err != nil { lastErr = err - logger.Debug(ctx, "title model candidate failed", - slog.F("chat_id", chat.ID), - slog.Error(err), - ) + if overrideSet { + logger.Warn(ctx, "title model candidate failed", + slog.F("chat_id", chat.ID), + slog.F("override_context", titleGenerationOverrideContext), + slog.F("provider", candidate.provider), + slog.F("model", candidate.model), + slog.Error(err), + ) + } else { + logger.Debug(ctx, "title model candidate failed", + slog.F("chat_id", chat.ID), + slog.Error(err), + ) + } continue } if title == "" || title == chat.Title { @@ -225,10 +266,18 @@ func (p *Server) maybeGenerateChatTitle( } if lastErr != nil { - logger.Debug(ctx, "all title model candidates failed", - slog.F("chat_id", chat.ID), - slog.Error(lastErr), - ) + if overrideSet { + logger.Warn(ctx, "all title model candidates failed", + slog.F("chat_id", chat.ID), + slog.F("override_context", titleGenerationOverrideContext), + slog.Error(lastErr), + ) + } else { + logger.Debug(ctx, "all title model candidates failed", + slog.F("chat_id", chat.ID), + slog.Error(lastErr), + ) + } } } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 52617e8847..ccf0d01446 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -104,12 +104,12 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool { } func subagentModelOverrideLogLabel( - overrideContext codersdk.ChatAgentModelOverrideContext, + overrideContext codersdk.ChatModelOverrideContext, ) string { switch overrideContext { - case codersdk.ChatAgentModelOverrideContextGeneral: + case codersdk.ChatModelOverrideContextGeneral: return "general delegated child" - case codersdk.ChatAgentModelOverrideContextExplore: + case codersdk.ChatModelOverrideContextExplore: return "explore" default: return string(overrideContext) @@ -119,16 +119,16 @@ func subagentModelOverrideLogLabel( func readSubagentModelOverride( ctx context.Context, db database.Store, - overrideContext codersdk.ChatAgentModelOverrideContext, + overrideContext codersdk.ChatModelOverrideContext, ) (string, error) { switch overrideContext { - case codersdk.ChatAgentModelOverrideContextGeneral: + case codersdk.ChatModelOverrideContextGeneral: return db.GetChatGeneralModelOverride(ctx) - case codersdk.ChatAgentModelOverrideContextExplore: + case codersdk.ChatModelOverrideContextExplore: return db.GetChatExploreModelOverride(ctx) default: return "", xerrors.Errorf( - "unknown subagent model override context %q", + "unsupported subagent model override context %q", overrideContext, ) } @@ -167,6 +167,20 @@ func enabledProviderContainsName( return false } +type modelOverrideFailureMode int + +const ( + modelOverrideFailureModeSoft modelOverrideFailureMode = iota + modelOverrideFailureModeHard +) + +func modelOverrideErrorLabel(overrideContext string) string { + return strings.ReplaceAll(overrideContext, "_", " ") +} + +// resolveConfiguredModelOverride returns ok when a usable override is +// resolved. In hard failure mode, ok is also true for configured but unusable +// overrides so callers can distinguish them from unset or malformed values. func (p *Server) resolveConfiguredModelOverride( ctx context.Context, overrideContext string, @@ -174,6 +188,7 @@ func (p *Server) resolveConfiguredModelOverride( ownerID uuid.UUID, resolveModelConfig modelOverrideConfigResolver, resolveProviderKeys modelOverrideProviderKeysResolver, + failureMode modelOverrideFailureMode, ) (database.ChatModelConfig, bool, error) { trimmed := strings.TrimSpace(raw) if trimmed == "" { @@ -189,13 +204,40 @@ func (p *Server) resolveConfiguredModelOverride( ) return database.ChatModelConfig{}, false, nil } + modelConfig, providerName, err := resolveModelConfig( ctx, configuredModelConfigID, ) if err != nil { + if failureMode == modelOverrideFailureModeHard { + label := modelOverrideErrorLabel(overrideContext) + switch { + case errors.Is(err, sql.ErrNoRows): + return database.ChatModelConfig{}, true, xerrors.Errorf( + "%s model override is unavailable: %s", + label, + configuredModelConfigID, + ) + case errors.Is(err, errInvalidModelOverrideMetadata): + return database.ChatModelConfig{}, true, xerrors.Errorf( + "%s model override metadata is invalid for %s: %w", + label, + configuredModelConfigID, + err, + ) + default: + return database.ChatModelConfig{}, true, xerrors.Errorf( + "resolve %s model override %s: %w", + label, + configuredModelConfigID, + err, + ) + } + } + switch { - case xerrors.Is(err, sql.ErrNoRows): + case errors.Is(err, sql.ErrNoRows): p.logger.Info(ctx, "model override is unavailable, ignoring", slog.F("override_context", overrideContext), @@ -218,6 +260,7 @@ func (p *Server) resolveConfiguredModelOverride( } return database.ChatModelConfig{}, false, nil } + providerKeys, err := resolveProviderKeys(ctx, ownerID) if err != nil { return database.ChatModelConfig{}, false, xerrors.Errorf( @@ -228,6 +271,14 @@ func (p *Server) resolveConfiguredModelOverride( if providerKeys.APIKey(providerName) == "" && !(chatprovider.ProviderAllowsAmbientCredentials(providerName) && providerKeys.HasProvider(providerName)) { + if failureMode == modelOverrideFailureModeHard { + return database.ChatModelConfig{}, true, xerrors.Errorf( + "%s model override credentials are unavailable for provider %q", + modelOverrideErrorLabel(overrideContext), + providerName, + ) + } + p.logger.Info(ctx, "model override credentials are unavailable, ignoring", slog.F("override_context", overrideContext), @@ -242,7 +293,7 @@ func (p *Server) resolveConfiguredModelOverride( func (p *Server) resolveSubagentModelConfigID( ctx context.Context, ownerID uuid.UUID, - overrideContext codersdk.ChatAgentModelOverrideContext, + overrideContext codersdk.ChatModelOverrideContext, ) (uuid.UUID, error) { //nolint:gocritic // Chatd needs its scoped deployment-config read access here. chatdCtx := dbauthz.AsChatd(ctx) @@ -261,6 +312,7 @@ func (p *Server) resolveSubagentModelConfigID( ownerID, p.resolveModelConfigAndNormalizedProvider, p.resolveUserProviderAPIKeys, + modelOverrideFailureModeSoft, ) if err != nil { return uuid.Nil, err diff --git a/coderd/x/chatd/subagent_catalog.go b/coderd/x/chatd/subagent_catalog.go index 2b08f45045..eea0004475 100644 --- a/coderd/x/chatd/subagent_catalog.go +++ b/coderd/x/chatd/subagent_catalog.go @@ -48,7 +48,7 @@ func allSubagentDefinitions() []subagentDefinition { modelConfigID, err := p.resolveSubagentModelConfigID( ctx, parent.OwnerID, - codersdk.ChatAgentModelOverrideContextGeneral, + codersdk.ChatModelOverrideContextGeneral, ) if err != nil { return childSubagentChatOptions{}, err @@ -67,7 +67,7 @@ func allSubagentDefinitions() []subagentDefinition { modelConfigID, err := p.resolveSubagentModelConfigID( ctx, turnParent.OwnerID, - codersdk.ChatAgentModelOverrideContextExplore, + codersdk.ChatModelOverrideContextExplore, ) if err != nil { return childSubagentChatOptions{}, err diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index de4c6c60fc..3827a894e4 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -834,6 +834,7 @@ func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider( ByProvider: map[string]string{"bedrock": ""}, }, nil }, + modelOverrideFailureModeSoft, ) require.NoError(t, err) require.True(t, ok) diff --git a/coderd/x/chatd/title_override.go b/coderd/x/chatd/title_override.go new file mode 100644 index 0000000000..b01bc1613b --- /dev/null +++ b/coderd/x/chatd/title_override.go @@ -0,0 +1,100 @@ +package chatd + +import ( + "context" + + "charm.land/fantasy" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +const titleGenerationOverrideContext = "title_generation" + +func readTitleGenerationModelOverride( + ctx context.Context, + db database.Store, +) (string, error) { + //nolint:gocritic // Chatd is internal, not a user, so this read uses AsChatd. + chatdCtx := dbauthz.AsChatd(ctx) + raw, err := db.GetChatTitleGenerationModelOverride(chatdCtx) + if err != nil { + return "", xerrors.Errorf( + "get chat title generation model override: %w", + err, + ) + } + return raw, nil +} + +// resolveTitleGenerationModelOverride resolves the deployment-wide title +// generation model override. It returns four values: +// +// - modelConfig and model: populated only on success. +// - overrideSet: true when the admin configured a non-empty override, +// regardless of whether resolution succeeded. Callers MUST always check +// err first; overrideSet alone does not imply the model is usable. +// - err: non-nil when resolution failed. DB read failure returns +// (zero, nil, false, err). With overrideSet=true, the override is +// configured but unusable (deleted model, missing credentials, etc.) and +// callers should treat this as a hard failure for explicit-override +// semantics, not a soft fallback. +// +// When the override is unset or stored as malformed, the function returns +// (zero, nil, false, nil) so callers can fall back to default behavior. +func (p *Server) resolveTitleGenerationModelOverride( + ctx context.Context, + chat database.Chat, + keys chatprovider.ProviderAPIKeys, +) (database.ChatModelConfig, fantasy.LanguageModel, bool, error) { + raw, err := readTitleGenerationModelOverride(ctx, p.db) + if err != nil { + return database.ChatModelConfig{}, nil, false, xerrors.Errorf( + "read title generation model override: %w", + err, + ) + } + + modelConfig, overrideSet, err := p.resolveConfiguredModelOverride( + ctx, + titleGenerationOverrideContext, + raw, + chat.OwnerID, + p.resolveModelConfigAndNormalizedProvider, + func(context.Context, uuid.UUID) (chatprovider.ProviderAPIKeys, error) { + return keys, nil + }, + modelOverrideFailureModeHard, + ) + if err != nil { + return database.ChatModelConfig{}, nil, overrideSet, err + } + if !overrideSet { + return database.ChatModelConfig{}, nil, false, nil + } + + model, err := chatprovider.ModelFromConfig( + modelConfig.Provider, + modelConfig.Model, + keys, + chatprovider.UserAgent(), + chatprovider.CoderHeaders(chat), + nil, + ) + if err != nil { + return database.ChatModelConfig{}, nil, true, xerrors.Errorf( + "create title generation model override: %w", + err, + ) + } + if model == nil { + return database.ChatModelConfig{}, nil, true, xerrors.Errorf( + "create title generation model override returned nil", + ) + } + + return modelConfig, model, true, nil +} diff --git a/coderd/x/chatd/title_override_test.go b/coderd/x/chatd/title_override_test.go new file mode 100644 index 0000000000..145f3c91d1 --- /dev/null +++ b/coderd/x/chatd/title_override_test.go @@ -0,0 +1,559 @@ +package chatd //nolint:testpackage // Tests internal title override helpers. + +import ( + "context" + "database/sql" + "sync/atomic" + "testing" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chattest" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) { + t.Parallel() + + t.Run("uses preferred model before fallback", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + wantTitle := "Preferred title" + + var requestCount atomic.Int32 + serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + requestCount.Add(1) + require.Equal(t, preferredTitleModels[1].model, req.Model) + return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`) + }) + keys := titleOverrideOpenAIKeys(serverURL) + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + t.Fatal("fallback model should not be called when preferred model works") + return nil, xerrors.New("unexpected fallback model call") + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil) + db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{ + ID: chat.ID, + Title: wantTitle, + }).Return(chatWithTitle(chat, wantTitle), nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + keys, + generated, + logger, + nil, + ) + + require.Equal(t, int32(1), requestCount.Load()) + gotTitle, ok := generated.Load() + require.True(t, ok) + require.Equal(t, wantTitle, gotTitle) + }) + + t.Run("falls back to chat model when preferred models are unavailable", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + wantTitle := "Fallback title" + + var fallbackCalls atomic.Int32 + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + fallbackCalls.Add(1) + return &fantasy.ObjectResponse{ + Object: map[string]any{"title": wantTitle}, + }, nil + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil) + db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{ + ID: chat.ID, + Title: wantTitle, + }).Return(chatWithTitle(chat, wantTitle), nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + chatprovider.ProviderAPIKeys{}, + generated, + logger, + nil, + ) + + require.Equal(t, int32(1), fallbackCalls.Load()) + gotTitle, ok := generated.Load() + require.True(t, ok) + require.Equal(t, wantTitle, gotTitle) + }) +} + +func TestMaybeGenerateChatTitle_TitleGenerationOverrideReadDBError(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + wantTitle := "Fallback title" + + var fallbackCalls atomic.Int32 + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + fallbackCalls.Add(1) + return &fantasy.ObjectResponse{ + Object: map[string]any{"title": wantTitle}, + }, nil + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", sql.ErrConnDone) + db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{ + ID: chat.ID, + Title: wantTitle, + }).Return(chatWithTitle(chat, wantTitle), nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + chatprovider.ProviderAPIKeys{}, + generated, + logger, + nil, + ) + + require.Equal(t, int32(1), fallbackCalls.Load()) + gotTitle, ok := generated.Load() + require.True(t, ok) + require.Equal(t, wantTitle, gotTitle) +} + +func TestMaybeGenerateChatTitle_TitleGenerationOverrideMalformedFallsThrough(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + wantTitle := "Fallback title" + + var fallbackCalls atomic.Int32 + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + fallbackCalls.Add(1) + return &fantasy.ObjectResponse{ + Object: map[string]any{"title": wantTitle}, + }, nil + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("not-a-uuid", nil) + db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{ + ID: chat.ID, + Title: wantTitle, + }).Return(chatWithTitle(chat, wantTitle), nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + chatprovider.ProviderAPIKeys{}, + generated, + logger, + nil, + ) + + require.Equal(t, int32(1), fallbackCalls.Load()) + gotTitle, ok := generated.Load() + require.True(t, ok) + require.Equal(t, wantTitle, gotTitle) +} + +func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + overrideConfig := titleOverrideModelConfig("gpt-4.1", true) + wantTitle := "Override title" + + var requestCount atomic.Int32 + serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + requestCount.Add(1) + require.Equal(t, overrideConfig.Model, req.Model) + return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`) + }) + keys := titleOverrideOpenAIKeys(serverURL) + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + t.Fatal("fallback model should not be called when override is usable") + return nil, xerrors.New("unexpected fallback model call") + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) + db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{ + ID: chat.ID, + Title: wantTitle, + }).Return(chatWithTitle(chat, wantTitle), nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + keys, + generated, + logger, + nil, + ) + + require.Equal(t, int32(1), requestCount.Load()) + gotTitle, ok := generated.Load() + require.True(t, ok) + require.Equal(t, wantTitle, gotTitle) +} + +func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUnusableSkips(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + overrideConfig := titleOverrideModelConfig("gpt-4.1", false) + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + t.Fatal("fallback model should not be called when override is unusable") + return nil, xerrors.New("unexpected fallback model call") + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) + db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + chatprovider.ProviderAPIKeys{}, + generated, + logger, + nil, + ) + + _, ok := generated.Load() + require.False(t, ok) +} + +func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, messages := titleOverrideTestChatAndMessages(t) + overrideConfig := titleOverrideModelConfig("gpt-4.1", true) + + var requestCount atomic.Int32 + serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + requestCount.Add(1) + require.Equal(t, overrideConfig.Model, req.Model) + return chattest.OpenAINonStreamingResponse(`{"title":""}`) + }) + keys := titleOverrideOpenAIKeys(serverURL) + fallbackModel := &chattest.FakeModel{ + GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + t.Fatal("fallback model should not be called after override call failure") + return nil, xerrors.New("unexpected fallback model call") + }, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) + db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + + generated := &generatedChatTitle{} + server := titleOverrideTestServer(db, logger) + server.maybeGenerateChatTitle( + ctx, + chat, + messages, + "openai", + "fallback-chat-model", + fallbackModel, + keys, + generated, + logger, + nil, + ) + + require.Equal(t, int32(1), requestCount.Load()) + _, ok := generated.Load() + require.False(t, ok) +} + +func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, _ := titleOverrideTestChatAndMessages(t) + preferredConfig := database.ChatModelConfig{ + ID: uuid.New(), + Provider: preferredTitleModels[1].provider, + Model: preferredTitleModels[1].model, + Enabled: true, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil) + db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{ + {Provider: "openai", Model: "gpt-4.1", Enabled: true}, + preferredConfig, + }, nil) + + server := titleOverrideTestServer(db, logger) + model, gotConfig, err := server.resolveManualTitleModel( + ctx, + db, + chat, + chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + ) + require.NoError(t, err) + require.NotNil(t, model) + require.Equal(t, preferredConfig, gotConfig) +} + +func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, _ := titleOverrideTestChatAndMessages(t) + preferredConfig := database.ChatModelConfig{ + ID: uuid.New(), + Provider: preferredTitleModels[1].provider, + Model: preferredTitleModels[1].model, + Enabled: true, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", sql.ErrConnDone) + db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{ + {Provider: "openai", Model: "gpt-4.1", Enabled: true}, + preferredConfig, + }, nil) + + server := titleOverrideTestServer(db, logger) + model, gotConfig, err := server.resolveManualTitleModel( + ctx, + db, + chat, + chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + ) + require.NoError(t, err) + require.NotNil(t, model) + require.Equal(t, preferredConfig, gotConfig) +} + +func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, _ := titleOverrideTestChatAndMessages(t) + overrideConfig := titleOverrideModelConfig("gpt-4.1", true) + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) + db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + + server := titleOverrideTestServer(db, logger) + model, gotConfig, err := server.resolveManualTitleModel( + ctx, + db, + chat, + chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + ) + require.NoError(t, err) + require.NotNil(t, model) + require.Equal(t, overrideConfig, gotConfig) +} + +func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, _ := titleOverrideTestChatAndMessages(t) + overrideConfig := titleOverrideModelConfig("gpt-4.1", true) + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) + db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + + server := titleOverrideTestServer(db, logger) + model, gotConfig, err := server.resolveManualTitleModel( + ctx, + db, + chat, + chatprovider.ProviderAPIKeys{}, + ) + require.Error(t, err) + require.ErrorContains(t, err, "resolve manual title generation model override") + require.ErrorContains(t, err, "credentials are unavailable") + require.Nil(t, model) + require.Equal(t, database.ChatModelConfig{}, gotConfig) +} + +func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, _ := titleOverrideTestChatAndMessages(t) + overrideConfig := titleOverrideModelConfig("gpt-4.1", false) + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) + db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) + + server := titleOverrideTestServer(db, logger) + model, gotConfig, err := server.resolveManualTitleModel( + ctx, + db, + chat, + chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + ) + require.Error(t, err) + require.ErrorContains(t, err, "resolve manual title generation model override") + require.ErrorContains(t, err, "title generation model override is unavailable") + require.Nil(t, model) + require.Equal(t, database.ChatModelConfig{}, gotConfig) +} + +func titleOverrideTestChatAndMessages(t *testing.T) (database.Chat, []database.ChatMessage) { + t.Helper() + + userPrompt := "review pull request 123 and fix comments" + chat := database.Chat{ + ID: uuid.New(), + OwnerID: uuid.New(), + Title: fallbackChatTitle(userPrompt), + } + message := mustChatMessage( + t, + database.ChatMessageRoleUser, + database.ChatMessageVisibilityBoth, + codersdk.ChatMessageText(userPrompt), + ) + message.ID = 1 + return chat, []database.ChatMessage{message} +} + +func titleOverrideTestServer(db database.Store, logger slog.Logger) *Server { + return &Server{ + db: db, + logger: logger, + configCache: newChatConfigCache(context.Background(), db, quartz.NewReal()), + } +} + +func titleOverrideModelConfig(model string, enabled bool) database.ChatModelConfig { + return database.ChatModelConfig{ + ID: uuid.New(), + Provider: "openai", + Model: model, + Enabled: enabled, + } +} + +func titleOverrideOpenAIKeys(serverURL string) chatprovider.ProviderAPIKeys { + return chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + "openai": "test-key", + }, + BaseURLByProvider: map[string]string{ + "openai": serverURL, + }, + } +} + +func chatWithTitle(chat database.Chat, title string) database.Chat { + chat.Title = title + return chat +} diff --git a/codersdk/chats.go b/codersdk/chats.go index 66c4d3f3b8..1d8682b733 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -562,45 +562,48 @@ type UpdateChatPlanModeInstructionsRequest struct { PlanModeInstructions string `json:"plan_mode_instructions"` } -// ChatAgentModelOverrideContext identifies which chat or subagent context -// a deployment override applies to. -type ChatAgentModelOverrideContext string +// ChatModelOverrideContext identifies which chat model override context a +// deployment override applies to. +type ChatModelOverrideContext string const ( - ChatAgentModelOverrideContextGeneral ChatAgentModelOverrideContext = "general" - ChatAgentModelOverrideContextExplore ChatAgentModelOverrideContext = "explore" + ChatModelOverrideContextGeneral ChatModelOverrideContext = "general" + ChatModelOverrideContextExplore ChatModelOverrideContext = "explore" + ChatModelOverrideContextTitleGeneration ChatModelOverrideContext = "title_generation" ) // Valid reports whether the override context is one of the supported values. -func (c ChatAgentModelOverrideContext) Valid() bool { +func (c ChatModelOverrideContext) Valid() bool { switch c { - case ChatAgentModelOverrideContextGeneral, - ChatAgentModelOverrideContextExplore: + case ChatModelOverrideContextGeneral, + ChatModelOverrideContextExplore, + ChatModelOverrideContextTitleGeneration: return true default: return false } } -// AllChatAgentModelOverrideContexts returns all supported override contexts. -func AllChatAgentModelOverrideContexts() []ChatAgentModelOverrideContext { - return []ChatAgentModelOverrideContext{ - ChatAgentModelOverrideContextGeneral, - ChatAgentModelOverrideContextExplore, +// AllChatModelOverrideContexts returns all supported override contexts. +func AllChatModelOverrideContexts() []ChatModelOverrideContext { + return []ChatModelOverrideContext{ + ChatModelOverrideContextGeneral, + ChatModelOverrideContextExplore, + ChatModelOverrideContextTitleGeneration, } } -// ChatAgentModelOverrideResponse is the response body for the chat agent -// model override configuration endpoint. -type ChatAgentModelOverrideResponse struct { - Context ChatAgentModelOverrideContext `json:"context"` - ModelConfigID string `json:"model_config_id"` - IsMalformed bool `json:"is_malformed"` +// ChatModelOverrideResponse is the response body for the chat model override +// configuration endpoint. +type ChatModelOverrideResponse struct { + Context ChatModelOverrideContext `json:"context"` + ModelConfigID string `json:"model_config_id"` + IsMalformed bool `json:"is_malformed"` } -// UpdateChatAgentModelOverrideRequest is the request body for updating the -// chat agent model override configuration endpoint. -type UpdateChatAgentModelOverrideRequest struct { +// UpdateChatModelOverrideRequest is the request body for updating the chat +// model override configuration endpoint. +type UpdateChatModelOverrideRequest struct { ModelConfigID string `json:"model_config_id"` } @@ -2098,30 +2101,30 @@ func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context, return nil } -// GetChatAgentModelOverride returns the deployment-wide chat agent model -// override for the requested context. -func (c *ExperimentalClient) GetChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext) (ChatAgentModelOverrideResponse, error) { +// GetChatModelOverride returns the deployment-wide chat model override for +// the requested context. +func (c *ExperimentalClient) GetChatModelOverride(ctx context.Context, override ChatModelOverrideContext) (ChatModelOverrideResponse, error) { path := fmt.Sprintf( - "/api/experimental/chats/config/agent-model-override/%s", + "/api/experimental/chats/config/model-override/%s", url.PathEscape(string(override)), ) res, err := c.Request(ctx, http.MethodGet, path, nil) if err != nil { - return ChatAgentModelOverrideResponse{}, err + return ChatModelOverrideResponse{}, err } defer res.Body.Close() if res.StatusCode != http.StatusOK { - return ChatAgentModelOverrideResponse{}, ReadBodyAsError(res) + return ChatModelOverrideResponse{}, ReadBodyAsError(res) } - var resp ChatAgentModelOverrideResponse + var resp ChatModelOverrideResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } -// UpdateChatAgentModelOverride updates the deployment-wide chat agent model -// override for the requested context. -func (c *ExperimentalClient) UpdateChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext, req UpdateChatAgentModelOverrideRequest) error { +// UpdateChatModelOverride updates the deployment-wide chat model override for +// the requested context. +func (c *ExperimentalClient) UpdateChatModelOverride(ctx context.Context, override ChatModelOverrideContext, req UpdateChatModelOverrideRequest) error { path := fmt.Sprintf( - "/api/experimental/chats/config/agent-model-override/%s", + "/api/experimental/chats/config/model-override/%s", url.PathEscape(string(override)), ) res, err := c.Request(ctx, http.MethodPut, path, req) diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 06d3437edf..58b29eb768 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3262,22 +3262,21 @@ class ExperimentalApiMethods { ); }; - getChatAgentModelOverride = async ( - context: TypesGen.ChatAgentModelOverrideContext, - ): Promise => { - const response = - await this.axios.get( - `/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`, - ); + getChatModelOverride = async ( + context: TypesGen.ChatModelOverrideContext, + ): Promise => { + const response = await this.axios.get( + `/api/experimental/chats/config/model-override/${encodeURIComponent(context)}`, + ); return response.data; }; - updateChatAgentModelOverride = async ( - context: TypesGen.ChatAgentModelOverrideContext, - req: TypesGen.UpdateChatAgentModelOverrideRequest, + updateChatModelOverride = async ( + context: TypesGen.ChatModelOverrideContext, + req: TypesGen.UpdateChatModelOverrideRequest, ): Promise => { await this.axios.put( - `/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`, + `/api/experimental/chats/config/model-override/${encodeURIComponent(context)}`, req, ); }; diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 7c276f78f9..e395c71108 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1320,25 +1320,6 @@ export interface Chat { readonly children: readonly Chat[]; } -// From codersdk/chats.go -export type ChatAgentModelOverrideContext = "explore" | "general"; - -export const ChatAgentModelOverrideContexts: ChatAgentModelOverrideContext[] = [ - "explore", - "general", -]; - -// From codersdk/chats.go -/** - * ChatAgentModelOverrideResponse is the response body for the chat agent - * model override configuration endpoint. - */ -export interface ChatAgentModelOverrideResponse { - readonly context: ChatAgentModelOverrideContext; - readonly model_config_id: string; - readonly is_malformed: boolean; -} - // From codersdk/chats.go /** * ChatAutoArchiveDaysResponse contains the current chat auto-archive setting. @@ -2095,6 +2076,29 @@ export interface ChatModelOpenRouterProviderOptions { readonly provider?: ChatModelOpenRouterProvider; } +// From codersdk/chats.go +export type ChatModelOverrideContext = + | "explore" + | "general" + | "title_generation"; + +export const ChatModelOverrideContexts: ChatModelOverrideContext[] = [ + "explore", + "general", + "title_generation", +]; + +// From codersdk/chats.go +/** + * ChatModelOverrideResponse is the response body for the chat model override + * configuration endpoint. + */ +export interface ChatModelOverrideResponse { + readonly context: ChatModelOverrideContext; + readonly model_config_id: string; + readonly is_malformed: boolean; +} + // From codersdk/chats.go /** * ChatModelProvider represents provider availability and model results. @@ -7804,15 +7808,6 @@ export interface UpdateAppearanceConfig { readonly announcement_banners: readonly BannerConfig[]; } -// From codersdk/chats.go -/** - * UpdateChatAgentModelOverrideRequest is the request body for updating the - * chat agent model override configuration endpoint. - */ -export interface UpdateChatAgentModelOverrideRequest { - readonly model_config_id: string; -} - // From codersdk/chats.go /** * UpdateChatAutoArchiveDaysRequest is a request to update the chat @@ -7854,6 +7849,15 @@ export interface UpdateChatModelConfigRequest { readonly model_config?: ChatModelCallConfig; } +// From codersdk/chats.go +/** + * UpdateChatModelOverrideRequest is the request body for updating the chat + * model override configuration endpoint. + */ +export interface UpdateChatModelOverrideRequest { + readonly model_config_id: string; +} + // From codersdk/chats.go /** * UpdateChatPlanModeInstructionsRequest is the request body for diff --git a/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx b/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx index b406a173ea..59d13ce8ce 100644 --- a/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx @@ -12,31 +12,30 @@ import { useAuthenticated } from "#/hooks/useAuthenticated"; import { RequirePermission } from "#/modules/permissions/RequirePermission"; import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView"; -const generalOverrideContext: TypesGen.ChatAgentModelOverrideContext = - "general"; -const exploreOverrideContext: TypesGen.ChatAgentModelOverrideContext = - "explore"; +const generalOverrideContext: TypesGen.ChatModelOverrideContext = "general"; +const exploreOverrideContext: TypesGen.ChatModelOverrideContext = "explore"; +const titleGenerationOverrideContext: TypesGen.ChatModelOverrideContext = + "title_generation"; -const chatAgentModelOverrideKey = ( - context: TypesGen.ChatAgentModelOverrideContext, -) => ["chat-agent-model-override", context] as const; +const chatModelOverrideKey = (context: TypesGen.ChatModelOverrideContext) => + ["chat-model-override", context] as const; -const chatAgentModelOverrideQuery = ( - context: TypesGen.ChatAgentModelOverrideContext, +const chatModelOverrideQuery = ( + context: TypesGen.ChatModelOverrideContext, ) => ({ - queryKey: chatAgentModelOverrideKey(context), - queryFn: () => API.experimental.getChatAgentModelOverride(context), + queryKey: chatModelOverrideKey(context), + queryFn: () => API.experimental.getChatModelOverride(context), }); -const updateChatAgentModelOverrideMutation = ( +const updateChatModelOverrideMutation = ( queryClient: QueryClient, - context: TypesGen.ChatAgentModelOverrideContext, + context: TypesGen.ChatModelOverrideContext, ) => ({ - mutationFn: (req: TypesGen.UpdateChatAgentModelOverrideRequest) => - API.experimental.updateChatAgentModelOverride(context, req), + mutationFn: (req: TypesGen.UpdateChatModelOverrideRequest) => + API.experimental.updateChatModelOverride(context, req), onSuccess: async () => { await queryClient.invalidateQueries({ - queryKey: chatAgentModelOverrideKey(context), + queryKey: chatModelOverrideKey(context), exact: true, }); }, @@ -48,25 +47,36 @@ const AgentSettingsAgentsPage: FC = () => { const canEditDeploymentConfig = permissions.editDeploymentConfig; const generalModelOverrideQuery = useQuery({ - ...chatAgentModelOverrideQuery(generalOverrideContext), + ...chatModelOverrideQuery(generalOverrideContext), enabled: canEditDeploymentConfig, }); const exploreModelOverrideQuery = useQuery({ - ...chatAgentModelOverrideQuery(exploreOverrideContext), + ...chatModelOverrideQuery(exploreOverrideContext), + enabled: canEditDeploymentConfig, + }); + const titleGenerationModelQuery = useQuery({ + ...chatModelOverrideQuery(titleGenerationOverrideContext), enabled: canEditDeploymentConfig, }); const modelConfigsQuery = useQuery(chatModelConfigs()); const saveGeneralModelOverrideMutation = useMutation( - updateChatAgentModelOverrideMutation(queryClient, generalOverrideContext), + updateChatModelOverrideMutation(queryClient, generalOverrideContext), + ); + const saveTitleGenerationModelMutation = useMutation( + updateChatModelOverrideMutation( + queryClient, + titleGenerationOverrideContext, + ), ); const saveExploreModelOverrideMutation = useMutation( - updateChatAgentModelOverrideMutation(queryClient, exploreOverrideContext), + updateChatModelOverrideMutation(queryClient, exploreOverrideContext), ); return ( { isSaveGeneralModelOverrideError={ saveGeneralModelOverrideMutation.isError } + onSaveTitleGenerationModel={saveTitleGenerationModelMutation.mutate} + isSavingTitleGenerationModel={ + saveTitleGenerationModelMutation.isPending + } + isSaveTitleGenerationModelError={ + saveTitleGenerationModelMutation.isError + } onSaveExploreModelOverride={saveExploreModelOverrideMutation.mutate} isSavingExploreModelOverride={ saveExploreModelOverrideMutation.isPending diff --git a/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.stories.tsx b/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.stories.tsx index a9158a1d85..457cd2a9c8 100644 --- a/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.stories.tsx @@ -10,6 +10,8 @@ const OVERRIDE_MALFORMED_WARNING = "The saved override is malformed and is being treated as unset. Click Save to clear it."; const UNAVAILABLE_SAVED_MODEL_WARNING = "The saved model is no longer enabled and will be ignored until you choose a new override."; +const TITLE_UNAVAILABLE_SAVED_MODEL_WARNING = + "The selected model is currently unavailable. Title generation will be skipped until you choose another model or clear this setting."; const buildModelConfig = ( overrides: Partial, @@ -28,15 +30,20 @@ const buildModelConfig = ( }); const buildOverrideData = ( - context: TypesGen.ChatAgentModelOverrideContext, - overrides: Partial = {}, -): TypesGen.ChatAgentModelOverrideResponse => ({ + context: TypesGen.ChatModelOverrideContext, + overrides: Partial = {}, +): TypesGen.ChatModelOverrideResponse => ({ context, model_config_id: "", is_malformed: false, ...overrides, }); +const buildTitleGenerationModelOverrideData = ( + overrides: Partial = {}, +): TypesGen.ChatModelOverrideResponse => + buildOverrideData("title_generation", overrides); + const generalModelConfig = buildModelConfig({ id: "model-general-gpt-4.1-mini", display_name: "GPT 4.1 Mini", @@ -50,6 +57,13 @@ const claudeSonnetModelConfig = buildModelConfig({ context_limit: 200_000, }); +const titleModelConfig = buildModelConfig({ + id: "model-title-gpt-4o-mini", + model: "gpt-4o-mini", + display_name: "GPT 4o Mini", + context_limit: 128_000, +}); + const exploreFallbackModelConfig = buildModelConfig({ id: "model-explore-blank-display", provider: "anthropic", @@ -65,6 +79,14 @@ const generalDisabledModelConfig = buildModelConfig({ enabled: false, }); +const titleDisabledModelConfig = buildModelConfig({ + id: "model-title-disabled", + model: "gpt-4o-mini-legacy", + display_name: "GPT 4o Mini Legacy", + enabled: false, + context_limit: 128_000, +}); + const exploreDisabledModelConfig = buildModelConfig({ id: "model-explore-disabled", provider: "anthropic", @@ -77,8 +99,10 @@ const exploreDisabledModelConfig = buildModelConfig({ const allModelConfigs: TypesGen.ChatModelConfig[] = [ generalModelConfig, claudeSonnetModelConfig, + titleModelConfig, exploreFallbackModelConfig, generalDisabledModelConfig, + titleDisabledModelConfig, exploreDisabledModelConfig, ]; @@ -86,6 +110,7 @@ const makeArgs = ( overrides: Partial = {}, ): AgentSettingsAgentsPageViewProps => ({ generalModelOverrideData: buildOverrideData("general"), + titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData(), exploreModelOverrideData: buildOverrideData("explore"), modelConfigsData: allModelConfigs, modelConfigsError: undefined, @@ -93,6 +118,9 @@ const makeArgs = ( onSaveGeneralModelOverride: fn(), isSavingGeneralModelOverride: false, isSaveGeneralModelOverrideError: false, + onSaveTitleGenerationModel: fn(), + isSavingTitleGenerationModel: false, + isSaveTitleGenerationModelError: false, onSaveExploreModelOverride: fn(), isSavingExploreModelOverride: false, isSaveExploreModelOverrideError: false, @@ -146,13 +174,28 @@ export const AllOverridesUnset: Story = { const headings = await canvas.findAllByRole("heading", { level: 3 }); expect(headings.map((heading) => heading.textContent?.trim())).toEqual([ "General model", + "Title generation model", "Explore subagent model", ]); + await canvas.findByText( + "Choose a model for generated chat titles. Leave unset to use Coder's default title algorithm, which currently tries fast title models for configured providers first, for example Claude Haiku, GPT-4o mini, and Gemini Flash, then falls back to the chat's current model. When a model is selected here, Coder uses only that model for title generation. Recommended title models are fast and low cost.", + ); - for (const headingName of ["General model", "Explore subagent model"]) { + const unsetSections = [ + { headingName: "General model", placeholder: "Use chat default" }, + { + headingName: "Title generation model", + placeholder: "Use title default", + }, + { + headingName: "Explore subagent model", + placeholder: "Use chat default", + }, + ]; + for (const { headingName, placeholder } of unsetSections) { const section = await getSection(canvasElement, headingName); expect( - within(section).getByRole("combobox", { name: "Use chat default" }), + within(section).getByRole("combobox", { name: placeholder }), ).toBeInTheDocument(); expect( within(section).getByRole("button", { name: "Save" }), @@ -166,12 +209,19 @@ export const EachOverrideSetToEnabledModel: Story = { generalModelOverrideData: buildOverrideData("general", { model_config_id: generalModelConfig.id, }), + titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({ + model_config_id: titleModelConfig.id, + }), exploreModelOverrideData: buildOverrideData("explore", { model_config_id: exploreFallbackModelConfig.id, }), }), play: async ({ canvasElement, args }) => { const generalSection = await getSection(canvasElement, "General model"); + const titleSection = await getSection( + canvasElement, + "Title generation model", + ); const exploreSection = await getSection( canvasElement, "Explore subagent model", @@ -183,6 +233,12 @@ export const EachOverrideSetToEnabledModel: Story = { }), ).toHaveTextContent("claude-sonnet-4-20250514"); + expect( + within(titleSection).getByRole("combobox", { + name: /gpt 4o mini/i, + }), + ).toHaveTextContent("GPT 4o Mini"); + await selectModelInSection( generalSection, canvasElement, @@ -203,6 +259,26 @@ export const EachOverrideSetToEnabledModel: Story = { ); }); + await selectModelInSection( + titleSection, + canvasElement, + /gpt 4o mini/i, + "Claude Sonnet 4", + ); + const titleSaveButton = within(titleSection).getByRole("button", { + name: "Save", + }); + await waitFor(() => { + expect(titleSaveButton).toBeEnabled(); + }); + await userEvent.click(titleSaveButton); + await waitFor(() => { + expect(args.onSaveTitleGenerationModel).toHaveBeenCalledWith( + { model_config_id: claudeSonnetModelConfig.id }, + expect.anything(), + ); + }); + const exploreClearButton = within(exploreSection).getByRole("button", { name: "Clear", }); @@ -228,18 +304,25 @@ export const MalformedOverridesRemainClearableAndSaveable: Story = { generalModelOverrideData: buildOverrideData("general", { is_malformed: true, }), + titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({ + is_malformed: true, + }), exploreModelOverrideData: buildOverrideData("explore", { is_malformed: true, }), }), play: async ({ canvasElement, args }) => { const generalSection = await getSection(canvasElement, "General model"); + const titleSection = await getSection( + canvasElement, + "Title generation model", + ); const exploreSection = await getSection( canvasElement, "Explore subagent model", ); - for (const section of [generalSection, exploreSection]) { + for (const section of [generalSection, titleSection, exploreSection]) { await within(section).findByText(OVERRIDE_MALFORMED_WARNING); } @@ -257,6 +340,20 @@ export const MalformedOverridesRemainClearableAndSaveable: Story = { ); }); + const titleSaveButton = within(titleSection).getByRole("button", { + name: "Save", + }); + await waitFor(() => { + expect(titleSaveButton).toBeEnabled(); + }); + await userEvent.click(titleSaveButton); + await waitFor(() => { + expect(args.onSaveTitleGenerationModel).toHaveBeenCalledWith( + { model_config_id: "" }, + expect.anything(), + ); + }); + const exploreSaveButton = within(exploreSection).getByRole("button", { name: "Save", }); @@ -278,12 +375,19 @@ export const UnavailableSavedModels: Story = { generalModelOverrideData: buildOverrideData("general", { model_config_id: generalDisabledModelConfig.id, }), + titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({ + model_config_id: titleDisabledModelConfig.id, + }), exploreModelOverrideData: buildOverrideData("explore", { model_config_id: exploreDisabledModelConfig.id, }), }), play: async ({ canvasElement }) => { const generalSection = await getSection(canvasElement, "General model"); + const titleSection = await getSection( + canvasElement, + "Title generation model", + ); const exploreSection = await getSection( canvasElement, "Explore subagent model", @@ -295,5 +399,13 @@ export const UnavailableSavedModels: Story = { within(section).getByRole("combobox", { name: "Unavailable model" }), ).toBeInTheDocument(); } + await within(titleSection).findByText( + TITLE_UNAVAILABLE_SAVED_MODEL_WARNING, + ); + expect( + within(titleSection).getByRole("combobox", { + name: "Unavailable model", + }), + ).toBeInTheDocument(); }, }; diff --git a/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx b/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx index 94006dd8e9..af332b78ed 100644 --- a/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx @@ -7,19 +7,23 @@ import { } from "./components/SubagentModelOverrideSettings"; type SaveModelOverride = ( - req: TypesGen.UpdateChatAgentModelOverrideRequest, + req: { readonly model_config_id: string }, options?: MutationCallbacks, ) => void; export interface AgentSettingsAgentsPageViewProps { - generalModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse; - exploreModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse; + generalModelOverrideData?: TypesGen.ChatModelOverrideResponse; + titleGenerationModelOverrideData?: TypesGen.ChatModelOverrideResponse; + exploreModelOverrideData?: TypesGen.ChatModelOverrideResponse; modelConfigsData: TypesGen.ChatModelConfig[] | undefined; modelConfigsError: unknown; isLoadingModelConfigs: boolean; onSaveGeneralModelOverride?: SaveModelOverride; isSavingGeneralModelOverride?: boolean; isSaveGeneralModelOverrideError?: boolean; + onSaveTitleGenerationModel: SaveModelOverride; + isSavingTitleGenerationModel: boolean; + isSaveTitleGenerationModelError: boolean; onSaveExploreModelOverride: SaveModelOverride; isSavingExploreModelOverride: boolean; isSaveExploreModelOverrideError: boolean; @@ -29,6 +33,7 @@ export const AgentSettingsAgentsPageView: FC< AgentSettingsAgentsPageViewProps > = ({ generalModelOverrideData, + titleGenerationModelOverrideData, exploreModelOverrideData, modelConfigsData, modelConfigsError, @@ -36,6 +41,9 @@ export const AgentSettingsAgentsPageView: FC< onSaveGeneralModelOverride, isSavingGeneralModelOverride = false, isSaveGeneralModelOverrideError = false, + onSaveTitleGenerationModel, + isSavingTitleGenerationModel, + isSaveTitleGenerationModelError, onSaveExploreModelOverride, isSavingExploreModelOverride, isSaveExploreModelOverrideError, @@ -77,6 +85,31 @@ export const AgentSettingsAgentsPageView: FC< /> )} +
+ + +
( model_config_id: "", is_malformed: false, }} + titleGenerationModelOverrideData={{ + context: "title_generation", + model_config_id: "", + is_malformed: false, + }} modelConfigsData={[]} modelConfigsError={undefined} isLoadingModelConfigs={false} + onSaveTitleGenerationModel={fn()} + isSavingTitleGenerationModel={false} + isSaveTitleGenerationModelError={false} onSaveExploreModelOverride={fn()} isSavingExploreModelOverride={false} isSaveExploreModelOverrideError={false} diff --git a/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx b/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx index 9e46313ea1..bd1009c317 100644 --- a/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx +++ b/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx @@ -22,7 +22,7 @@ interface UpdateModelOverrideRequest { interface SubagentModelOverrideSettingsProps { title: string; - description: ReactNode; + description?: ReactNode; modelOverrideData: ModelOverrideData | undefined; enabledModelConfigs: readonly TypesGen.ChatModelConfig[]; modelConfigsError: unknown; @@ -34,6 +34,8 @@ interface SubagentModelOverrideSettingsProps { isSaving: boolean; isSaveError: boolean; saveErrorMessage: string; + unsetPlaceholder?: string; + unavailableModelWarning?: string; showHeader?: boolean; disabled?: boolean; } @@ -61,6 +63,8 @@ export const SubagentModelOverrideSettings: FC< isSaving, isSaveError, saveErrorMessage, + unsetPlaceholder = "Use chat default", + unavailableModelWarning = "The saved model is no longer enabled and will be ignored until you choose a new override.", showHeader = true, disabled = false, }) => { @@ -104,9 +108,11 @@ export const SubagentModelOverrideSettings: FC<

{title}

-

- {description} -

+ {description && ( +

+ {description} +

+ )} )} form.setFieldValue("model_config_id", value)} disabled={isModelOverrideDisabled} placeholder={ - isUnavailableSavedModel ? "Unavailable model" : "Use chat default" + isUnavailableSavedModel ? "Unavailable model" : unsetPlaceholder } emptyMessage={ isLoading ? "Loading models..." : "No enabled models found." @@ -125,10 +131,7 @@ export const SubagentModelOverrideSettings: FC< /> {isUnavailableSavedModel && ( - - The saved model is no longer enabled and will be ignored until you - choose a new override. - + {unavailableModelWarning} )} {isMalformedOverride && (