mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add admin-configurable chat title generation model (#24838)
Adds an admin-configurable deployment-wide setting that controls which
model is used for chat title generation. Admins can pick any enabled
chat model config from the Agents settings page, or leave the setting
unset to keep the existing fast-models-then-chat-model fallback
algorithm.
When a model is selected, both automatic and manual title generation use
only that model, with no silent fallback. When the configured model is
disabled, lacks credentials, or is otherwise unusable, automatic title
generation is skipped entirely (best-effort) and manual title regeneration
returns a clear error, so admins notice the misconfiguration instead of
silently routing title traffic through another provider.
## Surface
- New deployment-wide setting stored as a `site_configs` row
(`agents_chat_title_generation_model_override`).
- New experimental endpoint `GET/PUT
/api/experimental/chats/config/model-override/{context}`.
- Frontend: title generation now appears as a third dropdown on the
Agents admin settings page alongside the existing general and explore
context overrides.
## DRY refactors folded in
Title generation is integrated as a third value of the existing
`ChatModelOverrideContext` type alongside `general` and `explore`,
sharing the parameterized HTTP route, SDK methods, generated types, and
frontend API plumbing rather than introducing a parallel surface. The
`Agent` prefix was dropped from the type and route since title
generation is not a delegated agent.
The chatd model-override resolver is also shared.
`resolveConfiguredModelOverride` now takes a `failureMode` parameter:
- Subagent overrides use soft failure: misconfigured overrides are
logged and the parent model is used.
- Title generation uses hard failure: misconfigured overrides return an
explicit error so manual title regeneration surfaces the
misconfiguration and automatic title generation skips instead of
silently falling back.
> Mux is acting on Mike's behalf.
This commit is contained in:
+2
-2
@@ -1192,8 +1192,8 @@ func New(options *Options) *API {
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions)
|
||||
r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions)
|
||||
r.Get("/agent-model-override/{context}", api.getChatAgentModelOverride)
|
||||
r.Put("/agent-model-override/{context}", api.putChatAgentModelOverride)
|
||||
r.Get("/model-override/{context}", api.getChatModelOverride)
|
||||
r.Put("/model-override/{context}", api.putChatModelOverride)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/debug-logging", api.getChatDebugLogging)
|
||||
|
||||
@@ -2967,6 +2967,13 @@ func (q *querier) GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
||||
return q.db.GetChatTemplateAllowlist(ctx)
|
||||
}
|
||||
|
||||
// GetChatTitleGenerationModelOverride returns the stored title generation
// model override from the wrapped store. The caller in ctx must be allowed
// to read rbac.ResourceDeploymentConfig; otherwise the authorization error
// is returned together with an empty string and the store is not queried.
func (q *querier) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||
// Authorize before touching the database so a denied caller learns nothing.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.db.GetChatTitleGenerationModelOverride(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
@@ -7517,6 +7524,13 @@ func (q *querier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllow
|
||||
return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
|
||||
}
|
||||
|
||||
// UpsertChatTitleGenerationModelOverride writes value as the title
// generation model override through the wrapped store. The caller in ctx
// must be allowed to update rbac.ResourceDeploymentConfig; otherwise the
// authorization error is returned and nothing is written.
func (q *querier) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||
// Authorize before writing so a denied caller cannot change the setting.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatTitleGenerationModelOverride(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
|
||||
@@ -918,6 +918,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatTitleGenerationModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatPlanModeInstructions(gomock.Any()).Return("", nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
@@ -1237,6 +1241,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatTitleGenerationModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatTitleGenerationModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatPlanModeInstructions(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
|
||||
@@ -1456,6 +1456,14 @@ func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetChatTitleGenerationModelOverride forwards to the wrapped store and
// returns its results unchanged. It observes the call duration (in seconds)
// on m.queryLatencies and increments m.queryCounts, labelled with the HTTP
// route and method extracted from ctx plus the query name.
func (m queryMetricsStore) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatTitleGenerationModelOverride(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatTitleGenerationModelOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatTitleGenerationModelOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
||||
@@ -5408,6 +5416,14 @@ func (m queryMetricsStore) UpsertChatTemplateAllowlist(ctx context.Context, temp
|
||||
return r0
|
||||
}
|
||||
|
||||
// UpsertChatTitleGenerationModelOverride forwards to the wrapped store and
// returns its error unchanged. It observes the call duration (in seconds)
// on m.queryLatencies and increments m.queryCounts, labelled with the HTTP
// route and method extracted from ctx plus the query name.
func (m queryMetricsStore) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatTitleGenerationModelOverride(ctx, value)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatTitleGenerationModelOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatTitleGenerationModelOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
|
||||
@@ -2687,6 +2687,21 @@ func (mr *MockStoreMockRecorder) GetChatTemplateAllowlist(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx)
|
||||
}
|
||||
|
||||
// GetChatTitleGenerationModelOverride mocks base method.
|
||||
func (m *MockStore) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatTitleGenerationModelOverride", ctx)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatTitleGenerationModelOverride indicates an expected call of GetChatTitleGenerationModelOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatTitleGenerationModelOverride(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTitleGenerationModelOverride", reflect.TypeOf((*MockStore)(nil).GetChatTitleGenerationModelOverride), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -10152,6 +10167,20 @@ func (mr *MockStoreMockRecorder) UpsertChatTemplateAllowlist(ctx, templateAllowl
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist)
|
||||
}
|
||||
|
||||
// UpsertChatTitleGenerationModelOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatTitleGenerationModelOverride", ctx, value)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatTitleGenerationModelOverride indicates an expected call of UpsertChatTitleGenerationModelOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatTitleGenerationModelOverride(ctx, value any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTitleGenerationModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatTitleGenerationModelOverride), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -362,6 +362,7 @@ type sqlcQuerier interface {
|
||||
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
// Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
||||
GetChatTitleGenerationModelOverride(ctx context.Context) (string, error)
|
||||
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
|
||||
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
|
||||
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
|
||||
@@ -1206,6 +1207,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||
|
||||
@@ -20731,6 +20731,18 @@ func (q *sqlQuerier) GetChatTemplateAllowlist(ctx context.Context) (string, erro
|
||||
return template_allowlist, err
|
||||
}
|
||||
|
||||
const getChatTitleGenerationModelOverride = `-- name: GetChatTitleGenerationModelOverride :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_title_generation_model_override'), '') :: text AS model_config_id
|
||||
`
|
||||
|
||||
// GetChatTitleGenerationModelOverride runs the
// getChatTitleGenerationModelOverride query and scans its single text
// column. The query wraps the site_configs lookup in COALESCE(..., ''), so
// an absent row yields an empty string rather than a no-rows error.
func (q *sqlQuerier) GetChatTitleGenerationModelOverride(ctx context.Context) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatTitleGenerationModelOverride)
|
||||
var model_config_id string
|
||||
err := row.Scan(&model_config_id)
|
||||
return model_config_id, err
|
||||
}
|
||||
|
||||
const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one
|
||||
SELECT
|
||||
COALESCE(
|
||||
@@ -21085,6 +21097,16 @@ func (q *sqlQuerier) UpsertChatTemplateAllowlist(ctx context.Context, templateAl
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatTitleGenerationModelOverride = `-- name: UpsertChatTitleGenerationModelOverride :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_title_generation_model_override', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_title_generation_model_override'
|
||||
`
|
||||
|
||||
// UpsertChatTitleGenerationModelOverride executes the
// upsertChatTitleGenerationModelOverride statement with value bound to $1,
// inserting or updating the
// 'agents_chat_title_generation_model_override' site_configs row.
func (q *sqlQuerier) UpsertChatTitleGenerationModelOverride(ctx context.Context, value string) error {
|
||||
// The sql.Result is not needed; only the execution error is reported.
_, err := q.db.ExecContext(ctx, upsertChatTitleGenerationModelOverride, value)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_workspace_ttl', $1::text)
|
||||
|
||||
@@ -183,6 +183,14 @@ SELECT
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override';
|
||||
|
||||
-- name: GetChatTitleGenerationModelOverride :one
-- Returns the value stored under 'agents_chat_title_generation_model_override',
-- or an empty string when no such site_configs row exists.
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_title_generation_model_override'), '') :: text AS model_config_id;
|
||||
|
||||
-- name: UpsertChatTitleGenerationModelOverride :exec
-- Inserts the 'agents_chat_title_generation_model_override' site_configs row
-- with value $1, or overwrites its value when the key already exists.
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_title_generation_model_override', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_title_generation_model_override';
|
||||
|
||||
-- name: GetChatDesktopEnabled :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
|
||||
|
||||
+58
-41
@@ -532,62 +532,72 @@ func (api *API) getChatModelOverrideConfig(
|
||||
return id, false, nil
|
||||
}
|
||||
|
||||
// parseChatModelOverrideContext converts raw into a
// codersdk.ChatModelOverrideContext. It returns the converted value when
// its Valid method reports true, and otherwise an empty context with an
// "unknown chat model override context" error that quotes raw.
//
// The removed (parseChatAgentModelOverrideContext) and added lines of this
// hunk are interleaved below.
func parseChatAgentModelOverrideContext(raw string) (codersdk.ChatAgentModelOverrideContext, error) {
|
||||
overrideContext := codersdk.ChatAgentModelOverrideContext(raw)
|
||||
func parseChatModelOverrideContext(raw string) (codersdk.ChatModelOverrideContext, error) {
|
||||
overrideContext := codersdk.ChatModelOverrideContext(raw)
|
||||
if overrideContext.Valid() {
|
||||
return overrideContext, nil
|
||||
}
|
||||
return "", xerrors.Errorf("unknown chat agent model override context %q", raw)
|
||||
return "", xerrors.Errorf("unknown chat model override context %q", raw)
|
||||
}
|
||||
|
||||
// chatModelOverrideSiteConfig bundles the store accessors and the
// human-readable label for one model override context.
//
// The removed (chatAgentModelOverrideSiteConfig) and added declaration
// lines of this hunk are interleaved below.
type chatAgentModelOverrideSiteConfig struct {
|
||||
type chatModelOverrideSiteConfig struct {
|
||||
// label names the override in messages, e.g. "general" or "title generation".
label string
|
||||
// getter reads the stored override value.
getter func(context.Context) (string, error)
|
||||
// upsert writes a new override value.
upsert func(context.Context, string) error
|
||||
}
|
||||
|
||||
// chatModelOverrideSiteConfig maps overrideContext to the label and the
// api.Database getter/upsert pair that back it. The general, explore, and
// title generation contexts are supported; any other value yields a zero
// config and an "unknown chat model override context" error.
//
// The removed (chatAgentModelOverrideSiteConfig) and added lines of this
// hunk are interleaved below.
func (api *API) chatAgentModelOverrideSiteConfig(
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (chatAgentModelOverrideSiteConfig, error) {
|
||||
func (api *API) chatModelOverrideSiteConfig(
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
) (chatModelOverrideSiteConfig, error) {
|
||||
switch overrideContext {
|
||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
||||
return chatAgentModelOverrideSiteConfig{
|
||||
case codersdk.ChatModelOverrideContextGeneral:
|
||||
return chatModelOverrideSiteConfig{
|
||||
label: "general",
|
||||
getter: api.Database.GetChatGeneralModelOverride,
|
||||
upsert: api.Database.UpsertChatGeneralModelOverride,
|
||||
}, nil
|
||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
||||
return chatAgentModelOverrideSiteConfig{
|
||||
case codersdk.ChatModelOverrideContextExplore:
|
||||
return chatModelOverrideSiteConfig{
|
||||
label: "explore",
|
||||
getter: api.Database.GetChatExploreModelOverride,
|
||||
upsert: api.Database.UpsertChatExploreModelOverride,
|
||||
}, nil
|
||||
// Title generation is the context added by this change.
case codersdk.ChatModelOverrideContextTitleGeneration:
|
||||
return chatModelOverrideSiteConfig{
|
||||
label: "title generation",
|
||||
getter: api.Database.GetChatTitleGenerationModelOverride,
|
||||
upsert: api.Database.UpsertChatTitleGenerationModelOverride,
|
||||
}, nil
|
||||
default:
|
||||
return chatAgentModelOverrideSiteConfig{}, xerrors.Errorf(
|
||||
"unknown chat agent model override context %q",
|
||||
return chatModelOverrideSiteConfig{}, xerrors.Errorf(
|
||||
"unknown chat model override context %q",
|
||||
overrideContext,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// readChatModelOverrideConfig looks up the site config for overrideContext
// and reads the stored override through api.getChatModelOverrideConfig,
// passing the config's label and getter. It returns that call's ID,
// malformed flag, and error, plus the label so callers can name the
// override in messages. When overrideContext is unknown the label is empty
// and the lookup error is returned.
//
// The removed (getChatAgentModelOverrideConfig) and added lines of this
// hunk are interleaved below.
func (api *API) getChatAgentModelOverrideConfig(
|
||||
func (api *API) readChatModelOverrideConfig(
|
||||
ctx context.Context,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (*uuid.UUID, bool, error) {
|
||||
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
) (*uuid.UUID, bool, string, error) {
|
||||
siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, false, "", err
|
||||
}
|
||||
return api.getChatModelOverrideConfig(ctx, string(overrideContext), siteConfig.getter)
|
||||
id, isMalformed, err := api.getChatModelOverrideConfig(ctx, siteConfig.label, siteConfig.getter)
|
||||
return id, isMalformed, siteConfig.label, err
|
||||
}
|
||||
|
||||
// upsertChatModelOverrideConfig looks up the site config for
// overrideContext and stores formatChatModelOverride(modelConfigID) through
// its upsert function. It returns the config's label alongside the upsert
// error so callers can name the override in messages. When overrideContext
// is unknown the label is empty and the lookup error is returned.
//
// The removed (upsertChatAgentModelOverrideConfig) and added lines of this
// hunk are interleaved below.
func (api *API) upsertChatAgentModelOverrideConfig(
|
||||
func (api *API) upsertChatModelOverrideConfig(
|
||||
ctx context.Context,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
modelConfigID *uuid.UUID,
|
||||
) error {
|
||||
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
|
||||
) (string, error) {
|
||||
siteConfig, err := api.chatModelOverrideSiteConfig(overrideContext)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
return siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
|
||||
return siteConfig.label, siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
@@ -3941,27 +3951,27 @@ func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Requ
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func readChatAgentModelOverrideContext(
|
||||
func readChatModelOverrideContext(
|
||||
rw http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) (codersdk.ChatAgentModelOverrideContext, bool) {
|
||||
) (codersdk.ChatModelOverrideContext, bool) {
|
||||
ctx := r.Context()
|
||||
rawContext := chi.URLParam(r, "context")
|
||||
overrideContext, err := parseChatAgentModelOverrideContext(rawContext)
|
||||
overrideContext, err := parseChatModelOverrideContext(rawContext)
|
||||
if err == nil {
|
||||
return overrideContext, true
|
||||
}
|
||||
validContextValues := make(
|
||||
[]string,
|
||||
0,
|
||||
len(codersdk.AllChatAgentModelOverrideContexts()),
|
||||
len(codersdk.AllChatModelOverrideContexts()),
|
||||
)
|
||||
for _, overrideContext := range codersdk.AllChatAgentModelOverrideContexts() {
|
||||
for _, overrideContext := range codersdk.AllChatModelOverrideContexts() {
|
||||
validContextValues = append(validContextValues, string(overrideContext))
|
||||
}
|
||||
validContexts := strings.Join(validContextValues, ", ")
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat agent model override context.",
|
||||
Message: "Invalid chat model override context.",
|
||||
Detail: fmt.Sprintf(
|
||||
"Expected one of %s. Got %q.",
|
||||
validContexts,
|
||||
@@ -3974,27 +3984,30 @@ func readChatAgentModelOverrideContext(
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) getChatModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
|
||||
overrideContext, ok := readChatModelOverrideContext(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
modelConfigID, isMalformed, err := api.getChatAgentModelOverrideConfig(ctx, overrideContext)
|
||||
modelConfigID, isMalformed, label, err := api.readChatModelOverrideConfig(ctx, overrideContext)
|
||||
if err != nil {
|
||||
if label == "" {
|
||||
label = string(overrideContext)
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("Internal error fetching %s model override.", overrideContext),
|
||||
Message: fmt.Sprintf("Internal error fetching %s model override.", label),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp := codersdk.ChatAgentModelOverrideResponse{
|
||||
resp := codersdk.ChatModelOverrideResponse{
|
||||
Context: overrideContext,
|
||||
ModelConfigID: formatChatModelOverride(modelConfigID),
|
||||
IsMalformed: isMalformed,
|
||||
@@ -4004,18 +4017,18 @@ func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) putChatModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
|
||||
overrideContext, ok := readChatModelOverrideContext(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateChatAgentModelOverrideRequest
|
||||
var req codersdk.UpdateChatModelOverrideRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
@@ -4035,9 +4048,13 @@ func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.upsertChatAgentModelOverrideConfig(ctx, overrideContext, modelConfigID); err != nil {
|
||||
label, err := api.upsertChatModelOverrideConfig(ctx, overrideContext, modelConfigID)
|
||||
if err != nil {
|
||||
if label == "" {
|
||||
label = string(overrideContext)
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("Internal error updating %s model override.", overrideContext),
|
||||
Message: fmt.Sprintf("Internal error updating %s model override.", label),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
|
||||
+27
-17
@@ -10061,28 +10061,28 @@ func TestChatModelOverrides(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type overrideResponse struct {
|
||||
context codersdk.ChatAgentModelOverrideContext
|
||||
context codersdk.ChatModelOverrideContext
|
||||
modelConfigID string
|
||||
isMalformed bool
|
||||
}
|
||||
|
||||
type settingTest struct {
|
||||
name string
|
||||
context codersdk.ChatAgentModelOverrideContext
|
||||
context codersdk.ChatModelOverrideContext
|
||||
dbGet func(context.Context, database.Store) (string, error)
|
||||
dbUpsert func(context.Context, database.Store, string) error
|
||||
}
|
||||
|
||||
settingPath := func(overrideContext codersdk.ChatAgentModelOverrideContext) string {
|
||||
return "/api/experimental/chats/config/agent-model-override/" + string(overrideContext)
|
||||
settingPath := func(overrideContext codersdk.ChatModelOverrideContext) string {
|
||||
return "/api/experimental/chats/config/model-override/" + string(overrideContext)
|
||||
}
|
||||
|
||||
getOverride := func(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
) (overrideResponse, error) {
|
||||
resp, err := client.GetChatAgentModelOverride(ctx, overrideContext)
|
||||
resp, err := client.GetChatModelOverride(ctx, overrideContext)
|
||||
if err != nil {
|
||||
return overrideResponse{}, err
|
||||
}
|
||||
@@ -10096,20 +10096,20 @@ func TestChatModelOverrides(t *testing.T) {
|
||||
putOverride := func(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
modelConfigID string,
|
||||
) error {
|
||||
return client.UpdateChatAgentModelOverride(
|
||||
return client.UpdateChatModelOverride(
|
||||
ctx,
|
||||
overrideContext,
|
||||
codersdk.UpdateChatAgentModelOverrideRequest{ModelConfigID: modelConfigID},
|
||||
codersdk.UpdateChatModelOverrideRequest{ModelConfigID: modelConfigID},
|
||||
)
|
||||
}
|
||||
|
||||
settings := []settingTest{
|
||||
{
|
||||
name: "General",
|
||||
context: codersdk.ChatAgentModelOverrideContextGeneral,
|
||||
context: codersdk.ChatModelOverrideContextGeneral,
|
||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||
return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||
},
|
||||
@@ -10119,7 +10119,7 @@ func TestChatModelOverrides(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Explore",
|
||||
context: codersdk.ChatAgentModelOverrideContextExplore,
|
||||
context: codersdk.ChatModelOverrideContextExplore,
|
||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||
return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||
},
|
||||
@@ -10127,6 +10127,16 @@ func TestChatModelOverrides(t *testing.T) {
|
||||
return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "TitleGeneration",
|
||||
context: codersdk.ChatModelOverrideContextTitleGeneration,
|
||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||
return db.GetChatTitleGenerationModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||
},
|
||||
dbUpsert: func(ctx context.Context, db database.Store, value string) error {
|
||||
return db.UpsertChatTitleGenerationModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, setting := range settings {
|
||||
@@ -10265,23 +10275,23 @@ func TestChatModelOverrides(t *testing.T) {
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
|
||||
unknownContext := codersdk.ChatModelOverrideContext("not-a-context")
|
||||
|
||||
_, err := getOverride(ctx, adminClient, unknownContext)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
|
||||
require.Equal(t, "Invalid chat model override context.", sdkErr.Message)
|
||||
require.Equal(
|
||||
t,
|
||||
`Expected one of general, explore. Got "not-a-context".`,
|
||||
`Expected one of general, explore, title_generation. Got "not-a-context".`,
|
||||
sdkErr.Detail,
|
||||
)
|
||||
|
||||
err = putOverride(ctx, adminClient, unknownContext, "")
|
||||
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
|
||||
require.Equal(t, "Invalid chat model override context.", sdkErr.Message)
|
||||
require.Equal(
|
||||
t,
|
||||
`Expected one of general, explore. Got "not-a-context".`,
|
||||
`Expected one of general, explore, title_generation. Got "not-a-context".`,
|
||||
sdkErr.Detail,
|
||||
)
|
||||
})
|
||||
@@ -10293,7 +10303,7 @@ func TestChatModelOverrides(t *testing.T) {
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
|
||||
unknownContext := codersdk.ChatModelOverrideContext("not-a-context")
|
||||
|
||||
_, err := getOverride(ctx, memberClient, unknownContext)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
|
||||
@@ -3110,6 +3110,26 @@ func (p *Server) resolveManualTitleModel(
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
|
||||
overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||
ctx,
|
||||
chat,
|
||||
keys,
|
||||
)
|
||||
if overrideErr != nil {
|
||||
if overrideSet {
|
||||
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
||||
"resolve manual title generation model override: %w",
|
||||
overrideErr,
|
||||
)
|
||||
}
|
||||
p.logger.Debug(ctx, "failed to resolve title generation model override for manual title",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(overrideErr),
|
||||
)
|
||||
} else if overrideSet {
|
||||
return overrideModel, overrideConfig, nil
|
||||
}
|
||||
|
||||
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to list manual title model configs",
|
||||
|
||||
@@ -636,6 +636,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
).Return(nil, nil)
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
||||
|
||||
gomock.InOrder(
|
||||
@@ -799,6 +800,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
).Return(nil, nil)
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
||||
|
||||
gomock.InOrder(
|
||||
|
||||
+80
-31
@@ -106,9 +106,10 @@ type generatedTitle struct {
|
||||
// maybeGenerateChatTitle generates an AI title for the chat when
|
||||
// appropriate (first user message, no assistant reply yet, and the
|
||||
// current title is either empty or still the fallback truncation).
|
||||
// It tries cheap, fast models first and falls back to the user's
|
||||
// chat model. It is a best-effort operation that logs and swallows
|
||||
// errors.
|
||||
// It uses the configured title generation model override when set.
|
||||
// Otherwise, it tries cheap, fast models first and falls back to the
|
||||
// user's chat model. It is a best-effort operation that logs and
|
||||
// swallows errors.
|
||||
func (p *Server) maybeGenerateChatTitle(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
@@ -130,28 +131,58 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Build candidate list: preferred lightweight models first,
|
||||
// then the user's chat model as last resort.
|
||||
candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||
titleCtx,
|
||||
chat,
|
||||
keys,
|
||||
)
|
||||
if overrideErr != nil {
|
||||
if overrideSet {
|
||||
logger.Warn(ctx, "title generation model override unavailable, skipping title generation",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("override_context", titleGenerationOverrideContext),
|
||||
slog.Error(overrideErr),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.Debug(ctx, "failed to resolve title generation model override",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("override_context", titleGenerationOverrideContext),
|
||||
slog.Error(overrideErr),
|
||||
)
|
||||
}
|
||||
|
||||
var candidates []shortTextCandidate
|
||||
if overrideSet {
|
||||
candidates = []shortTextCandidate{{
|
||||
provider: overrideConfig.Provider,
|
||||
model: overrideConfig.Model,
|
||||
lm: overrideModel,
|
||||
}}
|
||||
} else {
|
||||
// Build candidate list: preferred lightweight models first,
|
||||
// then the user's chat model as last resort.
|
||||
candidates = make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
}
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
|
||||
var historyTipMessageID int64
|
||||
if len(messages) > 0 {
|
||||
@@ -197,10 +228,20 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
finishDebugRun(err)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
logger.Debug(ctx, "title model candidate failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
if overrideSet {
|
||||
logger.Warn(ctx, "title model candidate failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("override_context", titleGenerationOverrideContext),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
logger.Debug(ctx, "title model candidate failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if title == "" || title == chat.Title {
|
||||
@@ -225,10 +266,18 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
logger.Debug(ctx, "all title model candidates failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(lastErr),
|
||||
)
|
||||
if overrideSet {
|
||||
logger.Warn(ctx, "all title model candidates failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("override_context", titleGenerationOverrideContext),
|
||||
slog.Error(lastErr),
|
||||
)
|
||||
} else {
|
||||
logger.Debug(ctx, "all title model candidates failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(lastErr),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -104,12 +104,12 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
||||
}
|
||||
|
||||
func subagentModelOverrideLogLabel(
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
) string {
|
||||
switch overrideContext {
|
||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
||||
case codersdk.ChatModelOverrideContextGeneral:
|
||||
return "general delegated child"
|
||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
||||
case codersdk.ChatModelOverrideContextExplore:
|
||||
return "explore"
|
||||
default:
|
||||
return string(overrideContext)
|
||||
@@ -119,16 +119,16 @@ func subagentModelOverrideLogLabel(
|
||||
func readSubagentModelOverride(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
) (string, error) {
|
||||
switch overrideContext {
|
||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
||||
case codersdk.ChatModelOverrideContextGeneral:
|
||||
return db.GetChatGeneralModelOverride(ctx)
|
||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
||||
case codersdk.ChatModelOverrideContextExplore:
|
||||
return db.GetChatExploreModelOverride(ctx)
|
||||
default:
|
||||
return "", xerrors.Errorf(
|
||||
"unknown subagent model override context %q",
|
||||
"unsupported subagent model override context %q",
|
||||
overrideContext,
|
||||
)
|
||||
}
|
||||
@@ -167,6 +167,20 @@ func enabledProviderContainsName(
|
||||
return false
|
||||
}
|
||||
|
||||
// modelOverrideFailureMode selects how resolveConfiguredModelOverride reports
// a configured override that cannot be used: the soft mode logs and ignores
// it, while the hard mode returns an error to the caller.
type modelOverrideFailureMode int
|
||||
|
||||
const (
|
||||
modelOverrideFailureModeSoft modelOverrideFailureMode = iota
|
||||
modelOverrideFailureModeHard
|
||||
)
|
||||
|
||||
// modelOverrideErrorLabel turns an override context identifier such as
// "title_generation" into a space-separated label for use in error messages.
func modelOverrideErrorLabel(overrideContext string) string {
	const (
		separator = "_"
		space     = " "
	)
	// A negative count replaces every occurrence of the separator.
	return strings.Replace(overrideContext, separator, space, -1)
}
|
||||
|
||||
// resolveConfiguredModelOverride returns ok when a usable override is
|
||||
// resolved. In hard failure mode, ok is also true for configured but unusable
|
||||
// overrides so callers can distinguish them from unset or malformed values.
|
||||
func (p *Server) resolveConfiguredModelOverride(
|
||||
ctx context.Context,
|
||||
overrideContext string,
|
||||
@@ -174,6 +188,7 @@ func (p *Server) resolveConfiguredModelOverride(
|
||||
ownerID uuid.UUID,
|
||||
resolveModelConfig modelOverrideConfigResolver,
|
||||
resolveProviderKeys modelOverrideProviderKeysResolver,
|
||||
failureMode modelOverrideFailureMode,
|
||||
) (database.ChatModelConfig, bool, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
@@ -189,13 +204,40 @@ func (p *Server) resolveConfiguredModelOverride(
|
||||
)
|
||||
return database.ChatModelConfig{}, false, nil
|
||||
}
|
||||
|
||||
modelConfig, providerName, err := resolveModelConfig(
|
||||
ctx,
|
||||
configuredModelConfigID,
|
||||
)
|
||||
if err != nil {
|
||||
if failureMode == modelOverrideFailureModeHard {
|
||||
label := modelOverrideErrorLabel(overrideContext)
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||
"%s model override is unavailable: %s",
|
||||
label,
|
||||
configuredModelConfigID,
|
||||
)
|
||||
case errors.Is(err, errInvalidModelOverrideMetadata):
|
||||
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||
"%s model override metadata is invalid for %s: %w",
|
||||
label,
|
||||
configuredModelConfigID,
|
||||
err,
|
||||
)
|
||||
default:
|
||||
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||
"resolve %s model override %s: %w",
|
||||
label,
|
||||
configuredModelConfigID,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case xerrors.Is(err, sql.ErrNoRows):
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
p.logger.Info(ctx,
|
||||
"model override is unavailable, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
@@ -218,6 +260,7 @@ func (p *Server) resolveConfiguredModelOverride(
|
||||
}
|
||||
return database.ChatModelConfig{}, false, nil
|
||||
}
|
||||
|
||||
providerKeys, err := resolveProviderKeys(ctx, ownerID)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, false, xerrors.Errorf(
|
||||
@@ -228,6 +271,14 @@ func (p *Server) resolveConfiguredModelOverride(
|
||||
if providerKeys.APIKey(providerName) == "" &&
|
||||
!(chatprovider.ProviderAllowsAmbientCredentials(providerName) &&
|
||||
providerKeys.HasProvider(providerName)) {
|
||||
if failureMode == modelOverrideFailureModeHard {
|
||||
return database.ChatModelConfig{}, true, xerrors.Errorf(
|
||||
"%s model override credentials are unavailable for provider %q",
|
||||
modelOverrideErrorLabel(overrideContext),
|
||||
providerName,
|
||||
)
|
||||
}
|
||||
|
||||
p.logger.Info(ctx,
|
||||
"model override credentials are unavailable, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
@@ -242,7 +293,7 @@ func (p *Server) resolveConfiguredModelOverride(
|
||||
func (p *Server) resolveSubagentModelConfigID(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
overrideContext codersdk.ChatModelOverrideContext,
|
||||
) (uuid.UUID, error) {
|
||||
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
@@ -261,6 +312,7 @@ func (p *Server) resolveSubagentModelConfigID(
|
||||
ownerID,
|
||||
p.resolveModelConfigAndNormalizedProvider,
|
||||
p.resolveUserProviderAPIKeys,
|
||||
modelOverrideFailureModeSoft,
|
||||
)
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
|
||||
@@ -48,7 +48,7 @@ func allSubagentDefinitions() []subagentDefinition {
|
||||
modelConfigID, err := p.resolveSubagentModelConfigID(
|
||||
ctx,
|
||||
parent.OwnerID,
|
||||
codersdk.ChatAgentModelOverrideContextGeneral,
|
||||
codersdk.ChatModelOverrideContextGeneral,
|
||||
)
|
||||
if err != nil {
|
||||
return childSubagentChatOptions{}, err
|
||||
@@ -67,7 +67,7 @@ func allSubagentDefinitions() []subagentDefinition {
|
||||
modelConfigID, err := p.resolveSubagentModelConfigID(
|
||||
ctx,
|
||||
turnParent.OwnerID,
|
||||
codersdk.ChatAgentModelOverrideContextExplore,
|
||||
codersdk.ChatModelOverrideContextExplore,
|
||||
)
|
||||
if err != nil {
|
||||
return childSubagentChatOptions{}, err
|
||||
|
||||
@@ -834,6 +834,7 @@ func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider(
|
||||
ByProvider: map[string]string{"bedrock": ""},
|
||||
}, nil
|
||||
},
|
||||
modelOverrideFailureModeSoft,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
// titleGenerationOverrideContext is the override context identifier passed to
// resolveConfiguredModelOverride when resolving the title generation model
// override.
const titleGenerationOverrideContext = "title_generation"
|
||||
|
||||
func readTitleGenerationModelOverride(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
) (string, error) {
|
||||
//nolint:gocritic // Chatd is internal, not a user, so this read uses AsChatd.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
raw, err := db.GetChatTitleGenerationModelOverride(chatdCtx)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf(
|
||||
"get chat title generation model override: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// resolveTitleGenerationModelOverride resolves the deployment-wide title
|
||||
// generation model override. It returns four values:
|
||||
//
|
||||
// - modelConfig and model: populated only on success.
|
||||
// - overrideSet: true when the admin configured a non-empty override,
|
||||
// regardless of whether resolution succeeded. Callers MUST always check
|
||||
// err first; overrideSet alone does not imply the model is usable.
|
||||
// - err: non-nil when resolution failed. DB read failure returns
|
||||
// (zero, nil, false, err). With overrideSet=true, the override is
|
||||
// configured but unusable (deleted model, missing credentials, etc.) and
|
||||
// callers should treat this as a hard failure for explicit-override
|
||||
// semantics, not a soft fallback.
|
||||
//
|
||||
// When the override is unset or stored as malformed, the function returns
|
||||
// (zero, nil, false, nil) so callers can fall back to default behavior.
|
||||
func (p *Server) resolveTitleGenerationModelOverride(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) (database.ChatModelConfig, fantasy.LanguageModel, bool, error) {
|
||||
raw, err := readTitleGenerationModelOverride(ctx, p.db)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, false, xerrors.Errorf(
|
||||
"read title generation model override: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
modelConfig, overrideSet, err := p.resolveConfiguredModelOverride(
|
||||
ctx,
|
||||
titleGenerationOverrideContext,
|
||||
raw,
|
||||
chat.OwnerID,
|
||||
p.resolveModelConfigAndNormalizedProvider,
|
||||
func(context.Context, uuid.UUID) (chatprovider.ProviderAPIKeys, error) {
|
||||
return keys, nil
|
||||
},
|
||||
modelOverrideFailureModeHard,
|
||||
)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, overrideSet, err
|
||||
}
|
||||
if !overrideSet {
|
||||
return database.ChatModelConfig{}, nil, false, nil
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, true, xerrors.Errorf(
|
||||
"create title generation model override: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
if model == nil {
|
||||
return database.ChatModelConfig{}, nil, true, xerrors.Errorf(
|
||||
"create title generation model override returned nil",
|
||||
)
|
||||
}
|
||||
|
||||
return modelConfig, model, true, nil
|
||||
}
|
||||
@@ -0,0 +1,559 @@
|
||||
package chatd //nolint:testpackage // Tests internal title override helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("uses preferred model before fallback", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
wantTitle := "Preferred title"
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
requestCount.Add(1)
|
||||
require.Equal(t, preferredTitleModels[1].model, req.Model)
|
||||
return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`)
|
||||
})
|
||||
keys := titleOverrideOpenAIKeys(serverURL)
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
t.Fatal("fallback model should not be called when preferred model works")
|
||||
return nil, xerrors.New("unexpected fallback model call")
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: wantTitle,
|
||||
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
keys,
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, int32(1), requestCount.Load())
|
||||
gotTitle, ok := generated.Load()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, wantTitle, gotTitle)
|
||||
})
|
||||
|
||||
t.Run("falls back to chat model when preferred models are unavailable", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
wantTitle := "Fallback title"
|
||||
|
||||
var fallbackCalls atomic.Int32
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
fallbackCalls.Add(1)
|
||||
return &fantasy.ObjectResponse{
|
||||
Object: map[string]any{"title": wantTitle},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: wantTitle,
|
||||
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, int32(1), fallbackCalls.Load())
|
||||
gotTitle, ok := generated.Load()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, wantTitle, gotTitle)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaybeGenerateChatTitle_TitleGenerationOverrideReadDBError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
wantTitle := "Fallback title"
|
||||
|
||||
var fallbackCalls atomic.Int32
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
fallbackCalls.Add(1)
|
||||
return &fantasy.ObjectResponse{
|
||||
Object: map[string]any{"title": wantTitle},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", sql.ErrConnDone)
|
||||
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: wantTitle,
|
||||
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, int32(1), fallbackCalls.Load())
|
||||
gotTitle, ok := generated.Load()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, wantTitle, gotTitle)
|
||||
}
|
||||
|
||||
func TestMaybeGenerateChatTitle_TitleGenerationOverrideMalformedFallsThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
wantTitle := "Fallback title"
|
||||
|
||||
var fallbackCalls atomic.Int32
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
fallbackCalls.Add(1)
|
||||
return &fantasy.ObjectResponse{
|
||||
Object: map[string]any{"title": wantTitle},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("not-a-uuid", nil)
|
||||
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: wantTitle,
|
||||
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, int32(1), fallbackCalls.Load())
|
||||
gotTitle, ok := generated.Load()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, wantTitle, gotTitle)
|
||||
}
|
||||
|
||||
func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||
wantTitle := "Override title"
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
requestCount.Add(1)
|
||||
require.Equal(t, overrideConfig.Model, req.Model)
|
||||
return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`)
|
||||
})
|
||||
keys := titleOverrideOpenAIKeys(serverURL)
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
t.Fatal("fallback model should not be called when override is usable")
|
||||
return nil, xerrors.New("unexpected fallback model call")
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
db.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: wantTitle,
|
||||
}).Return(chatWithTitle(chat, wantTitle), nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
keys,
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, int32(1), requestCount.Load())
|
||||
gotTitle, ok := generated.Load()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, wantTitle, gotTitle)
|
||||
}
|
||||
|
||||
func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUnusableSkips(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", false)
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
t.Fatal("fallback model should not be called when override is unusable")
|
||||
return nil, xerrors.New("unexpected fallback model call")
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
_, ok := generated.Load()
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
requestCount.Add(1)
|
||||
require.Equal(t, overrideConfig.Model, req.Model)
|
||||
return chattest.OpenAINonStreamingResponse(`{"title":""}`)
|
||||
})
|
||||
keys := titleOverrideOpenAIKeys(serverURL)
|
||||
fallbackModel := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
t.Fatal("fallback model should not be called after override call failure")
|
||||
return nil, xerrors.New("unexpected fallback model call")
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
|
||||
generated := &generatedChatTitle{}
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.maybeGenerateChatTitle(
|
||||
ctx,
|
||||
chat,
|
||||
messages,
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
keys,
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Equal(t, int32(1), requestCount.Load())
|
||||
_, ok := generated.Load()
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||
preferredConfig := database.ChatModelConfig{
|
||||
ID: uuid.New(),
|
||||
Provider: preferredTitleModels[1].provider,
|
||||
Model: preferredTitleModels[1].model,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{
|
||||
{Provider: "openai", Model: "gpt-4.1", Enabled: true},
|
||||
preferredConfig,
|
||||
}, nil)
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, err := server.resolveManualTitleModel(
|
||||
ctx,
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
require.Equal(t, preferredConfig, gotConfig)
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||
preferredConfig := database.ChatModelConfig{
|
||||
ID: uuid.New(),
|
||||
Provider: preferredTitleModels[1].provider,
|
||||
Model: preferredTitleModels[1].model,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", sql.ErrConnDone)
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{
|
||||
{Provider: "openai", Model: "gpt-4.1", Enabled: true},
|
||||
preferredConfig,
|
||||
}, nil)
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, err := server.resolveManualTitleModel(
|
||||
ctx,
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
require.Equal(t, preferredConfig, gotConfig)
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, err := server.resolveManualTitleModel(
|
||||
ctx,
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
require.Equal(t, overrideConfig, gotConfig)
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil)
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, err := server.resolveManualTitleModel(
|
||||
ctx,
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "resolve manual title generation model override")
|
||||
require.ErrorContains(t, err, "credentials are unavailable")
|
||||
require.Nil(t, model)
|
||||
require.Equal(t, database.ChatModelConfig{}, gotConfig)
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, _ := titleOverrideTestChatAndMessages(t)
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", false)
|
||||
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
model, gotConfig, err := server.resolveManualTitleModel(
|
||||
ctx,
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "resolve manual title generation model override")
|
||||
require.ErrorContains(t, err, "title generation model override is unavailable")
|
||||
require.Nil(t, model)
|
||||
require.Equal(t, database.ChatModelConfig{}, gotConfig)
|
||||
}
|
||||
|
||||
func titleOverrideTestChatAndMessages(t *testing.T) (database.Chat, []database.ChatMessage) {
|
||||
t.Helper()
|
||||
|
||||
userPrompt := "review pull request 123 and fix comments"
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
OwnerID: uuid.New(),
|
||||
Title: fallbackChatTitle(userPrompt),
|
||||
}
|
||||
message := mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleUser,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText(userPrompt),
|
||||
)
|
||||
message.ID = 1
|
||||
return chat, []database.ChatMessage{message}
|
||||
}
|
||||
|
||||
func titleOverrideTestServer(db database.Store, logger slog.Logger) *Server {
|
||||
return &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
configCache: newChatConfigCache(context.Background(), db, quartz.NewReal()),
|
||||
}
|
||||
}
|
||||
|
||||
func titleOverrideModelConfig(model string, enabled bool) database.ChatModelConfig {
|
||||
return database.ChatModelConfig{
|
||||
ID: uuid.New(),
|
||||
Provider: "openai",
|
||||
Model: model,
|
||||
Enabled: enabled,
|
||||
}
|
||||
}
|
||||
|
||||
func titleOverrideOpenAIKeys(serverURL string) chatprovider.ProviderAPIKeys {
|
||||
return chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
"openai": "test-key",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
"openai": serverURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func chatWithTitle(chat database.Chat, title string) database.Chat {
|
||||
chat.Title = title
|
||||
return chat
|
||||
}
|
||||
+36
-33
@@ -562,45 +562,48 @@ type UpdateChatPlanModeInstructionsRequest struct {
|
||||
PlanModeInstructions string `json:"plan_mode_instructions"`
|
||||
}
|
||||
|
||||
// ChatAgentModelOverrideContext identifies which chat or subagent context
|
||||
// a deployment override applies to.
|
||||
type ChatAgentModelOverrideContext string
|
||||
// ChatModelOverrideContext identifies which chat model override context a
|
||||
// deployment override applies to.
|
||||
type ChatModelOverrideContext string
|
||||
|
||||
const (
|
||||
ChatAgentModelOverrideContextGeneral ChatAgentModelOverrideContext = "general"
|
||||
ChatAgentModelOverrideContextExplore ChatAgentModelOverrideContext = "explore"
|
||||
ChatModelOverrideContextGeneral ChatModelOverrideContext = "general"
|
||||
ChatModelOverrideContextExplore ChatModelOverrideContext = "explore"
|
||||
ChatModelOverrideContextTitleGeneration ChatModelOverrideContext = "title_generation"
|
||||
)
|
||||
|
||||
// Valid reports whether the override context is one of the supported values.
|
||||
func (c ChatAgentModelOverrideContext) Valid() bool {
|
||||
func (c ChatModelOverrideContext) Valid() bool {
|
||||
switch c {
|
||||
case ChatAgentModelOverrideContextGeneral,
|
||||
ChatAgentModelOverrideContextExplore:
|
||||
case ChatModelOverrideContextGeneral,
|
||||
ChatModelOverrideContextExplore,
|
||||
ChatModelOverrideContextTitleGeneration:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// AllChatAgentModelOverrideContexts returns all supported override contexts.
|
||||
func AllChatAgentModelOverrideContexts() []ChatAgentModelOverrideContext {
|
||||
return []ChatAgentModelOverrideContext{
|
||||
ChatAgentModelOverrideContextGeneral,
|
||||
ChatAgentModelOverrideContextExplore,
|
||||
// AllChatModelOverrideContexts returns all supported override contexts.
|
||||
func AllChatModelOverrideContexts() []ChatModelOverrideContext {
|
||||
return []ChatModelOverrideContext{
|
||||
ChatModelOverrideContextGeneral,
|
||||
ChatModelOverrideContextExplore,
|
||||
ChatModelOverrideContextTitleGeneration,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatAgentModelOverrideResponse is the response body for the chat agent
|
||||
// model override configuration endpoint.
|
||||
type ChatAgentModelOverrideResponse struct {
|
||||
Context ChatAgentModelOverrideContext `json:"context"`
|
||||
ModelConfigID string `json:"model_config_id"`
|
||||
IsMalformed bool `json:"is_malformed"`
|
||||
// ChatModelOverrideResponse is the response body for the chat model override
|
||||
// configuration endpoint.
|
||||
type ChatModelOverrideResponse struct {
|
||||
Context ChatModelOverrideContext `json:"context"`
|
||||
ModelConfigID string `json:"model_config_id"`
|
||||
IsMalformed bool `json:"is_malformed"`
|
||||
}
|
||||
|
||||
// UpdateChatAgentModelOverrideRequest is the request body for updating the
|
||||
// chat agent model override configuration endpoint.
|
||||
type UpdateChatAgentModelOverrideRequest struct {
|
||||
// UpdateChatModelOverrideRequest is the request body for updating the chat
|
||||
// model override configuration endpoint.
|
||||
type UpdateChatModelOverrideRequest struct {
|
||||
ModelConfigID string `json:"model_config_id"`
|
||||
}
|
||||
|
||||
@@ -2098,30 +2101,30 @@ func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatAgentModelOverride returns the deployment-wide chat agent model
|
||||
// override for the requested context.
|
||||
func (c *ExperimentalClient) GetChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext) (ChatAgentModelOverrideResponse, error) {
|
||||
// GetChatModelOverride returns the deployment-wide chat model override for
|
||||
// the requested context.
|
||||
func (c *ExperimentalClient) GetChatModelOverride(ctx context.Context, override ChatModelOverrideContext) (ChatModelOverrideResponse, error) {
|
||||
path := fmt.Sprintf(
|
||||
"/api/experimental/chats/config/agent-model-override/%s",
|
||||
"/api/experimental/chats/config/model-override/%s",
|
||||
url.PathEscape(string(override)),
|
||||
)
|
||||
res, err := c.Request(ctx, http.MethodGet, path, nil)
|
||||
if err != nil {
|
||||
return ChatAgentModelOverrideResponse{}, err
|
||||
return ChatModelOverrideResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatAgentModelOverrideResponse{}, ReadBodyAsError(res)
|
||||
return ChatModelOverrideResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatAgentModelOverrideResponse
|
||||
var resp ChatModelOverrideResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// UpdateChatAgentModelOverride updates the deployment-wide chat agent model
|
||||
// override for the requested context.
|
||||
func (c *ExperimentalClient) UpdateChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext, req UpdateChatAgentModelOverrideRequest) error {
|
||||
// UpdateChatModelOverride updates the deployment-wide chat model override for
|
||||
// the requested context.
|
||||
func (c *ExperimentalClient) UpdateChatModelOverride(ctx context.Context, override ChatModelOverrideContext, req UpdateChatModelOverrideRequest) error {
|
||||
path := fmt.Sprintf(
|
||||
"/api/experimental/chats/config/agent-model-override/%s",
|
||||
"/api/experimental/chats/config/model-override/%s",
|
||||
url.PathEscape(string(override)),
|
||||
)
|
||||
res, err := c.Request(ctx, http.MethodPut, path, req)
|
||||
|
||||
+10
-11
@@ -3262,22 +3262,21 @@ class ExperimentalApiMethods {
|
||||
);
|
||||
};
|
||||
|
||||
getChatAgentModelOverride = async (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
): Promise<TypesGen.ChatAgentModelOverrideResponse> => {
|
||||
const response =
|
||||
await this.axios.get<TypesGen.ChatAgentModelOverrideResponse>(
|
||||
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
|
||||
);
|
||||
getChatModelOverride = async (
|
||||
context: TypesGen.ChatModelOverrideContext,
|
||||
): Promise<TypesGen.ChatModelOverrideResponse> => {
|
||||
const response = await this.axios.get<TypesGen.ChatModelOverrideResponse>(
|
||||
`/api/experimental/chats/config/model-override/${encodeURIComponent(context)}`,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
updateChatAgentModelOverride = async (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
req: TypesGen.UpdateChatAgentModelOverrideRequest,
|
||||
updateChatModelOverride = async (
|
||||
context: TypesGen.ChatModelOverrideContext,
|
||||
req: TypesGen.UpdateChatModelOverrideRequest,
|
||||
): Promise<void> => {
|
||||
await this.axios.put(
|
||||
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
|
||||
`/api/experimental/chats/config/model-override/${encodeURIComponent(context)}`,
|
||||
req,
|
||||
);
|
||||
};
|
||||
|
||||
Generated
+32
-28
@@ -1320,25 +1320,6 @@ export interface Chat {
|
||||
readonly children: readonly Chat[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatAgentModelOverrideContext = "explore" | "general";
|
||||
|
||||
export const ChatAgentModelOverrideContexts: ChatAgentModelOverrideContext[] = [
|
||||
"explore",
|
||||
"general",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatAgentModelOverrideResponse is the response body for the chat agent
|
||||
* model override configuration endpoint.
|
||||
*/
|
||||
export interface ChatAgentModelOverrideResponse {
|
||||
readonly context: ChatAgentModelOverrideContext;
|
||||
readonly model_config_id: string;
|
||||
readonly is_malformed: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatAutoArchiveDaysResponse contains the current chat auto-archive setting.
|
||||
@@ -2095,6 +2076,29 @@ export interface ChatModelOpenRouterProviderOptions {
|
||||
readonly provider?: ChatModelOpenRouterProvider;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatModelOverrideContext =
|
||||
| "explore"
|
||||
| "general"
|
||||
| "title_generation";
|
||||
|
||||
export const ChatModelOverrideContexts: ChatModelOverrideContext[] = [
|
||||
"explore",
|
||||
"general",
|
||||
"title_generation",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatModelOverrideResponse is the response body for the chat model override
|
||||
* configuration endpoint.
|
||||
*/
|
||||
export interface ChatModelOverrideResponse {
|
||||
readonly context: ChatModelOverrideContext;
|
||||
readonly model_config_id: string;
|
||||
readonly is_malformed: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatModelProvider represents provider availability and model results.
|
||||
@@ -7804,15 +7808,6 @@ export interface UpdateAppearanceConfig {
|
||||
readonly announcement_banners: readonly BannerConfig[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatAgentModelOverrideRequest is the request body for updating the
|
||||
* chat agent model override configuration endpoint.
|
||||
*/
|
||||
export interface UpdateChatAgentModelOverrideRequest {
|
||||
readonly model_config_id: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatAutoArchiveDaysRequest is a request to update the chat
|
||||
@@ -7854,6 +7849,15 @@ export interface UpdateChatModelConfigRequest {
|
||||
readonly model_config?: ChatModelCallConfig;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatModelOverrideRequest is the request body for updating the chat
|
||||
* model override configuration endpoint.
|
||||
*/
|
||||
export interface UpdateChatModelOverrideRequest {
|
||||
readonly model_config_id: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatPlanModeInstructionsRequest is the request body for
|
||||
|
||||
@@ -12,31 +12,30 @@ import { useAuthenticated } from "#/hooks/useAuthenticated";
|
||||
import { RequirePermission } from "#/modules/permissions/RequirePermission";
|
||||
import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView";
|
||||
|
||||
const generalOverrideContext: TypesGen.ChatAgentModelOverrideContext =
|
||||
"general";
|
||||
const exploreOverrideContext: TypesGen.ChatAgentModelOverrideContext =
|
||||
"explore";
|
||||
const generalOverrideContext: TypesGen.ChatModelOverrideContext = "general";
|
||||
const exploreOverrideContext: TypesGen.ChatModelOverrideContext = "explore";
|
||||
const titleGenerationOverrideContext: TypesGen.ChatModelOverrideContext =
|
||||
"title_generation";
|
||||
|
||||
const chatAgentModelOverrideKey = (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
) => ["chat-agent-model-override", context] as const;
|
||||
const chatModelOverrideKey = (context: TypesGen.ChatModelOverrideContext) =>
|
||||
["chat-model-override", context] as const;
|
||||
|
||||
const chatAgentModelOverrideQuery = (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
const chatModelOverrideQuery = (
|
||||
context: TypesGen.ChatModelOverrideContext,
|
||||
) => ({
|
||||
queryKey: chatAgentModelOverrideKey(context),
|
||||
queryFn: () => API.experimental.getChatAgentModelOverride(context),
|
||||
queryKey: chatModelOverrideKey(context),
|
||||
queryFn: () => API.experimental.getChatModelOverride(context),
|
||||
});
|
||||
|
||||
const updateChatAgentModelOverrideMutation = (
|
||||
const updateChatModelOverrideMutation = (
|
||||
queryClient: QueryClient,
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
context: TypesGen.ChatModelOverrideContext,
|
||||
) => ({
|
||||
mutationFn: (req: TypesGen.UpdateChatAgentModelOverrideRequest) =>
|
||||
API.experimental.updateChatAgentModelOverride(context, req),
|
||||
mutationFn: (req: TypesGen.UpdateChatModelOverrideRequest) =>
|
||||
API.experimental.updateChatModelOverride(context, req),
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: chatAgentModelOverrideKey(context),
|
||||
queryKey: chatModelOverrideKey(context),
|
||||
exact: true,
|
||||
});
|
||||
},
|
||||
@@ -48,25 +47,36 @@ const AgentSettingsAgentsPage: FC = () => {
|
||||
const canEditDeploymentConfig = permissions.editDeploymentConfig;
|
||||
|
||||
const generalModelOverrideQuery = useQuery({
|
||||
...chatAgentModelOverrideQuery(generalOverrideContext),
|
||||
...chatModelOverrideQuery(generalOverrideContext),
|
||||
enabled: canEditDeploymentConfig,
|
||||
});
|
||||
const exploreModelOverrideQuery = useQuery({
|
||||
...chatAgentModelOverrideQuery(exploreOverrideContext),
|
||||
...chatModelOverrideQuery(exploreOverrideContext),
|
||||
enabled: canEditDeploymentConfig,
|
||||
});
|
||||
const titleGenerationModelQuery = useQuery({
|
||||
...chatModelOverrideQuery(titleGenerationOverrideContext),
|
||||
enabled: canEditDeploymentConfig,
|
||||
});
|
||||
const modelConfigsQuery = useQuery(chatModelConfigs());
|
||||
const saveGeneralModelOverrideMutation = useMutation(
|
||||
updateChatAgentModelOverrideMutation(queryClient, generalOverrideContext),
|
||||
updateChatModelOverrideMutation(queryClient, generalOverrideContext),
|
||||
);
|
||||
const saveTitleGenerationModelMutation = useMutation(
|
||||
updateChatModelOverrideMutation(
|
||||
queryClient,
|
||||
titleGenerationOverrideContext,
|
||||
),
|
||||
);
|
||||
const saveExploreModelOverrideMutation = useMutation(
|
||||
updateChatAgentModelOverrideMutation(queryClient, exploreOverrideContext),
|
||||
updateChatModelOverrideMutation(queryClient, exploreOverrideContext),
|
||||
);
|
||||
|
||||
return (
|
||||
<RequirePermission isFeatureVisible={canEditDeploymentConfig}>
|
||||
<AgentSettingsAgentsPageView
|
||||
generalModelOverrideData={generalModelOverrideQuery.data}
|
||||
titleGenerationModelOverrideData={titleGenerationModelQuery.data}
|
||||
exploreModelOverrideData={exploreModelOverrideQuery.data}
|
||||
modelConfigsData={modelConfigsQuery.data}
|
||||
modelConfigsError={modelConfigsQuery.error}
|
||||
@@ -78,6 +88,13 @@ const AgentSettingsAgentsPage: FC = () => {
|
||||
isSaveGeneralModelOverrideError={
|
||||
saveGeneralModelOverrideMutation.isError
|
||||
}
|
||||
onSaveTitleGenerationModel={saveTitleGenerationModelMutation.mutate}
|
||||
isSavingTitleGenerationModel={
|
||||
saveTitleGenerationModelMutation.isPending
|
||||
}
|
||||
isSaveTitleGenerationModelError={
|
||||
saveTitleGenerationModelMutation.isError
|
||||
}
|
||||
onSaveExploreModelOverride={saveExploreModelOverrideMutation.mutate}
|
||||
isSavingExploreModelOverride={
|
||||
saveExploreModelOverrideMutation.isPending
|
||||
|
||||
@@ -10,6 +10,8 @@ const OVERRIDE_MALFORMED_WARNING =
|
||||
"The saved override is malformed and is being treated as unset. Click Save to clear it.";
|
||||
const UNAVAILABLE_SAVED_MODEL_WARNING =
|
||||
"The saved model is no longer enabled and will be ignored until you choose a new override.";
|
||||
const TITLE_UNAVAILABLE_SAVED_MODEL_WARNING =
|
||||
"The selected model is currently unavailable. Title generation will be skipped until you choose another model or clear this setting.";
|
||||
|
||||
const buildModelConfig = (
|
||||
overrides: Partial<TypesGen.ChatModelConfig>,
|
||||
@@ -28,15 +30,20 @@ const buildModelConfig = (
|
||||
});
|
||||
|
||||
const buildOverrideData = (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
overrides: Partial<TypesGen.ChatAgentModelOverrideResponse> = {},
|
||||
): TypesGen.ChatAgentModelOverrideResponse => ({
|
||||
context: TypesGen.ChatModelOverrideContext,
|
||||
overrides: Partial<TypesGen.ChatModelOverrideResponse> = {},
|
||||
): TypesGen.ChatModelOverrideResponse => ({
|
||||
context,
|
||||
model_config_id: "",
|
||||
is_malformed: false,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
// Builds an override response fixture fixed to the "title_generation"
// context, delegating to buildOverrideData with the given field overrides.
const buildTitleGenerationModelOverrideData = (
	overrides: Partial<TypesGen.ChatModelOverrideResponse> = {},
): TypesGen.ChatModelOverrideResponse => {
	return buildOverrideData("title_generation", overrides);
};
|
||||
|
||||
const generalModelConfig = buildModelConfig({
|
||||
id: "model-general-gpt-4.1-mini",
|
||||
display_name: "GPT 4.1 Mini",
|
||||
@@ -50,6 +57,13 @@ const claudeSonnetModelConfig = buildModelConfig({
|
||||
context_limit: 200_000,
|
||||
});
|
||||
|
||||
const titleModelConfig = buildModelConfig({
|
||||
id: "model-title-gpt-4o-mini",
|
||||
model: "gpt-4o-mini",
|
||||
display_name: "GPT 4o Mini",
|
||||
context_limit: 128_000,
|
||||
});
|
||||
|
||||
const exploreFallbackModelConfig = buildModelConfig({
|
||||
id: "model-explore-blank-display",
|
||||
provider: "anthropic",
|
||||
@@ -65,6 +79,14 @@ const generalDisabledModelConfig = buildModelConfig({
|
||||
enabled: false,
|
||||
});
|
||||
|
||||
const titleDisabledModelConfig = buildModelConfig({
|
||||
id: "model-title-disabled",
|
||||
model: "gpt-4o-mini-legacy",
|
||||
display_name: "GPT 4o Mini Legacy",
|
||||
enabled: false,
|
||||
context_limit: 128_000,
|
||||
});
|
||||
|
||||
const exploreDisabledModelConfig = buildModelConfig({
|
||||
id: "model-explore-disabled",
|
||||
provider: "anthropic",
|
||||
@@ -77,8 +99,10 @@ const exploreDisabledModelConfig = buildModelConfig({
|
||||
const allModelConfigs: TypesGen.ChatModelConfig[] = [
|
||||
generalModelConfig,
|
||||
claudeSonnetModelConfig,
|
||||
titleModelConfig,
|
||||
exploreFallbackModelConfig,
|
||||
generalDisabledModelConfig,
|
||||
titleDisabledModelConfig,
|
||||
exploreDisabledModelConfig,
|
||||
];
|
||||
|
||||
@@ -86,6 +110,7 @@ const makeArgs = (
|
||||
overrides: Partial<AgentSettingsAgentsPageViewProps> = {},
|
||||
): AgentSettingsAgentsPageViewProps => ({
|
||||
generalModelOverrideData: buildOverrideData("general"),
|
||||
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData(),
|
||||
exploreModelOverrideData: buildOverrideData("explore"),
|
||||
modelConfigsData: allModelConfigs,
|
||||
modelConfigsError: undefined,
|
||||
@@ -93,6 +118,9 @@ const makeArgs = (
|
||||
onSaveGeneralModelOverride: fn(),
|
||||
isSavingGeneralModelOverride: false,
|
||||
isSaveGeneralModelOverrideError: false,
|
||||
onSaveTitleGenerationModel: fn(),
|
||||
isSavingTitleGenerationModel: false,
|
||||
isSaveTitleGenerationModelError: false,
|
||||
onSaveExploreModelOverride: fn(),
|
||||
isSavingExploreModelOverride: false,
|
||||
isSaveExploreModelOverrideError: false,
|
||||
@@ -146,13 +174,28 @@ export const AllOverridesUnset: Story = {
|
||||
const headings = await canvas.findAllByRole("heading", { level: 3 });
|
||||
expect(headings.map((heading) => heading.textContent?.trim())).toEqual([
|
||||
"General model",
|
||||
"Title generation model",
|
||||
"Explore subagent model",
|
||||
]);
|
||||
await canvas.findByText(
|
||||
"Choose a model for generated chat titles. Leave unset to use Coder's default title algorithm, which currently tries fast title models for configured providers first, for example Claude Haiku, GPT-4o mini, and Gemini Flash, then falls back to the chat's current model. When a model is selected here, Coder uses only that model for title generation. Recommended title models are fast and low cost.",
|
||||
);
|
||||
|
||||
for (const headingName of ["General model", "Explore subagent model"]) {
|
||||
const unsetSections = [
|
||||
{ headingName: "General model", placeholder: "Use chat default" },
|
||||
{
|
||||
headingName: "Title generation model",
|
||||
placeholder: "Use title default",
|
||||
},
|
||||
{
|
||||
headingName: "Explore subagent model",
|
||||
placeholder: "Use chat default",
|
||||
},
|
||||
];
|
||||
for (const { headingName, placeholder } of unsetSections) {
|
||||
const section = await getSection(canvasElement, headingName);
|
||||
expect(
|
||||
within(section).getByRole("combobox", { name: "Use chat default" }),
|
||||
within(section).getByRole("combobox", { name: placeholder }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
within(section).getByRole("button", { name: "Save" }),
|
||||
@@ -166,12 +209,19 @@ export const EachOverrideSetToEnabledModel: Story = {
|
||||
generalModelOverrideData: buildOverrideData("general", {
|
||||
model_config_id: generalModelConfig.id,
|
||||
}),
|
||||
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({
|
||||
model_config_id: titleModelConfig.id,
|
||||
}),
|
||||
exploreModelOverrideData: buildOverrideData("explore", {
|
||||
model_config_id: exploreFallbackModelConfig.id,
|
||||
}),
|
||||
}),
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const generalSection = await getSection(canvasElement, "General model");
|
||||
const titleSection = await getSection(
|
||||
canvasElement,
|
||||
"Title generation model",
|
||||
);
|
||||
const exploreSection = await getSection(
|
||||
canvasElement,
|
||||
"Explore subagent model",
|
||||
@@ -183,6 +233,12 @@ export const EachOverrideSetToEnabledModel: Story = {
|
||||
}),
|
||||
).toHaveTextContent("claude-sonnet-4-20250514");
|
||||
|
||||
expect(
|
||||
within(titleSection).getByRole("combobox", {
|
||||
name: /gpt 4o mini/i,
|
||||
}),
|
||||
).toHaveTextContent("GPT 4o Mini");
|
||||
|
||||
await selectModelInSection(
|
||||
generalSection,
|
||||
canvasElement,
|
||||
@@ -203,6 +259,26 @@ export const EachOverrideSetToEnabledModel: Story = {
|
||||
);
|
||||
});
|
||||
|
||||
await selectModelInSection(
|
||||
titleSection,
|
||||
canvasElement,
|
||||
/gpt 4o mini/i,
|
||||
"Claude Sonnet 4",
|
||||
);
|
||||
const titleSaveButton = within(titleSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(titleSaveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(titleSaveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveTitleGenerationModel).toHaveBeenCalledWith(
|
||||
{ model_config_id: claudeSonnetModelConfig.id },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
const exploreClearButton = within(exploreSection).getByRole("button", {
|
||||
name: "Clear",
|
||||
});
|
||||
@@ -228,18 +304,25 @@ export const MalformedOverridesRemainClearableAndSaveable: Story = {
|
||||
generalModelOverrideData: buildOverrideData("general", {
|
||||
is_malformed: true,
|
||||
}),
|
||||
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({
|
||||
is_malformed: true,
|
||||
}),
|
||||
exploreModelOverrideData: buildOverrideData("explore", {
|
||||
is_malformed: true,
|
||||
}),
|
||||
}),
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const generalSection = await getSection(canvasElement, "General model");
|
||||
const titleSection = await getSection(
|
||||
canvasElement,
|
||||
"Title generation model",
|
||||
);
|
||||
const exploreSection = await getSection(
|
||||
canvasElement,
|
||||
"Explore subagent model",
|
||||
);
|
||||
|
||||
for (const section of [generalSection, exploreSection]) {
|
||||
for (const section of [generalSection, titleSection, exploreSection]) {
|
||||
await within(section).findByText(OVERRIDE_MALFORMED_WARNING);
|
||||
}
|
||||
|
||||
@@ -257,6 +340,20 @@ export const MalformedOverridesRemainClearableAndSaveable: Story = {
|
||||
);
|
||||
});
|
||||
|
||||
const titleSaveButton = within(titleSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(titleSaveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(titleSaveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveTitleGenerationModel).toHaveBeenCalledWith(
|
||||
{ model_config_id: "" },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
const exploreSaveButton = within(exploreSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
@@ -278,12 +375,19 @@ export const UnavailableSavedModels: Story = {
|
||||
generalModelOverrideData: buildOverrideData("general", {
|
||||
model_config_id: generalDisabledModelConfig.id,
|
||||
}),
|
||||
titleGenerationModelOverrideData: buildTitleGenerationModelOverrideData({
|
||||
model_config_id: titleDisabledModelConfig.id,
|
||||
}),
|
||||
exploreModelOverrideData: buildOverrideData("explore", {
|
||||
model_config_id: exploreDisabledModelConfig.id,
|
||||
}),
|
||||
}),
|
||||
play: async ({ canvasElement }) => {
|
||||
const generalSection = await getSection(canvasElement, "General model");
|
||||
const titleSection = await getSection(
|
||||
canvasElement,
|
||||
"Title generation model",
|
||||
);
|
||||
const exploreSection = await getSection(
|
||||
canvasElement,
|
||||
"Explore subagent model",
|
||||
@@ -295,5 +399,13 @@ export const UnavailableSavedModels: Story = {
|
||||
within(section).getByRole("combobox", { name: "Unavailable model" }),
|
||||
).toBeInTheDocument();
|
||||
}
|
||||
await within(titleSection).findByText(
|
||||
TITLE_UNAVAILABLE_SAVED_MODEL_WARNING,
|
||||
);
|
||||
expect(
|
||||
within(titleSection).getByRole("combobox", {
|
||||
name: "Unavailable model",
|
||||
}),
|
||||
).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
@@ -7,19 +7,23 @@ import {
|
||||
} from "./components/SubagentModelOverrideSettings";
|
||||
|
||||
type SaveModelOverride = (
|
||||
req: TypesGen.UpdateChatAgentModelOverrideRequest,
|
||||
req: { readonly model_config_id: string },
|
||||
options?: MutationCallbacks,
|
||||
) => void;
|
||||
|
||||
export interface AgentSettingsAgentsPageViewProps {
|
||||
generalModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
|
||||
exploreModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
|
||||
generalModelOverrideData?: TypesGen.ChatModelOverrideResponse;
|
||||
titleGenerationModelOverrideData?: TypesGen.ChatModelOverrideResponse;
|
||||
exploreModelOverrideData?: TypesGen.ChatModelOverrideResponse;
|
||||
modelConfigsData: TypesGen.ChatModelConfig[] | undefined;
|
||||
modelConfigsError: unknown;
|
||||
isLoadingModelConfigs: boolean;
|
||||
onSaveGeneralModelOverride?: SaveModelOverride;
|
||||
isSavingGeneralModelOverride?: boolean;
|
||||
isSaveGeneralModelOverrideError?: boolean;
|
||||
onSaveTitleGenerationModel: SaveModelOverride;
|
||||
isSavingTitleGenerationModel: boolean;
|
||||
isSaveTitleGenerationModelError: boolean;
|
||||
onSaveExploreModelOverride: SaveModelOverride;
|
||||
isSavingExploreModelOverride: boolean;
|
||||
isSaveExploreModelOverrideError: boolean;
|
||||
@@ -29,6 +33,7 @@ export const AgentSettingsAgentsPageView: FC<
|
||||
AgentSettingsAgentsPageViewProps
|
||||
> = ({
|
||||
generalModelOverrideData,
|
||||
titleGenerationModelOverrideData,
|
||||
exploreModelOverrideData,
|
||||
modelConfigsData,
|
||||
modelConfigsError,
|
||||
@@ -36,6 +41,9 @@ export const AgentSettingsAgentsPageView: FC<
|
||||
onSaveGeneralModelOverride,
|
||||
isSavingGeneralModelOverride = false,
|
||||
isSaveGeneralModelOverrideError = false,
|
||||
onSaveTitleGenerationModel,
|
||||
isSavingTitleGenerationModel,
|
||||
isSaveTitleGenerationModelError,
|
||||
onSaveExploreModelOverride,
|
||||
isSavingExploreModelOverride,
|
||||
isSaveExploreModelOverrideError,
|
||||
@@ -77,6 +85,31 @@ export const AgentSettingsAgentsPageView: FC<
|
||||
/>
|
||||
</section>
|
||||
)}
|
||||
<section
|
||||
aria-label="Title generation model"
|
||||
className="flex flex-col gap-3"
|
||||
>
|
||||
<SectionHeader
|
||||
label="Title generation model"
|
||||
description="Choose a model for generated chat titles. Leave unset to use Coder's default title algorithm, which currently tries fast title models for configured providers first, for example Claude Haiku, GPT-4o mini, and Gemini Flash, then falls back to the chat's current model. When a model is selected here, Coder uses only that model for title generation. Recommended title models are fast and low cost."
|
||||
level="section"
|
||||
/>
|
||||
<SubagentModelOverrideSettings
|
||||
title="Title generation model"
|
||||
description="Choose a model for generated chat titles."
|
||||
modelOverrideData={titleGenerationModelOverrideData}
|
||||
enabledModelConfigs={enabledModelConfigs}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoading={isLoadingModelConfigs}
|
||||
onSaveModelOverride={onSaveTitleGenerationModel}
|
||||
isSaving={isSavingTitleGenerationModel}
|
||||
isSaveError={isSaveTitleGenerationModelError}
|
||||
saveErrorMessage="Failed to save title generation model."
|
||||
unsetPlaceholder="Use title default"
|
||||
unavailableModelWarning="The selected model is currently unavailable. Title generation will be skipped until you choose another model or clear this setting."
|
||||
showHeader={false}
|
||||
/>
|
||||
</section>
|
||||
<section
|
||||
aria-label="Explore subagent model"
|
||||
className="flex flex-col gap-3"
|
||||
|
||||
@@ -161,9 +161,17 @@ const AgentsRouteElement = () => (
|
||||
model_config_id: "",
|
||||
is_malformed: false,
|
||||
}}
|
||||
titleGenerationModelOverrideData={{
|
||||
context: "title_generation",
|
||||
model_config_id: "",
|
||||
is_malformed: false,
|
||||
}}
|
||||
modelConfigsData={[]}
|
||||
modelConfigsError={undefined}
|
||||
isLoadingModelConfigs={false}
|
||||
onSaveTitleGenerationModel={fn()}
|
||||
isSavingTitleGenerationModel={false}
|
||||
isSaveTitleGenerationModelError={false}
|
||||
onSaveExploreModelOverride={fn()}
|
||||
isSavingExploreModelOverride={false}
|
||||
isSaveExploreModelOverrideError={false}
|
||||
|
||||
@@ -22,7 +22,7 @@ interface UpdateModelOverrideRequest {
|
||||
|
||||
interface SubagentModelOverrideSettingsProps {
|
||||
title: string;
|
||||
description: ReactNode;
|
||||
description?: ReactNode;
|
||||
modelOverrideData: ModelOverrideData | undefined;
|
||||
enabledModelConfigs: readonly TypesGen.ChatModelConfig[];
|
||||
modelConfigsError: unknown;
|
||||
@@ -34,6 +34,8 @@ interface SubagentModelOverrideSettingsProps {
|
||||
isSaving: boolean;
|
||||
isSaveError: boolean;
|
||||
saveErrorMessage: string;
|
||||
unsetPlaceholder?: string;
|
||||
unavailableModelWarning?: string;
|
||||
showHeader?: boolean;
|
||||
disabled?: boolean;
|
||||
}
|
||||
@@ -61,6 +63,8 @@ export const SubagentModelOverrideSettings: FC<
|
||||
isSaving,
|
||||
isSaveError,
|
||||
saveErrorMessage,
|
||||
unsetPlaceholder = "Use chat default",
|
||||
unavailableModelWarning = "The saved model is no longer enabled and will be ignored until you choose a new override.",
|
||||
showHeader = true,
|
||||
disabled = false,
|
||||
}) => {
|
||||
@@ -104,9 +108,11 @@ export const SubagentModelOverrideSettings: FC<
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
{title}
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
{description}
|
||||
</p>
|
||||
{description && (
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
{description}
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<ModelSelector
|
||||
@@ -115,7 +121,7 @@ export const SubagentModelOverrideSettings: FC<
|
||||
onValueChange={(value) => form.setFieldValue("model_config_id", value)}
|
||||
disabled={isModelOverrideDisabled}
|
||||
placeholder={
|
||||
isUnavailableSavedModel ? "Unavailable model" : "Use chat default"
|
||||
isUnavailableSavedModel ? "Unavailable model" : unsetPlaceholder
|
||||
}
|
||||
emptyMessage={
|
||||
isLoading ? "Loading models..." : "No enabled models found."
|
||||
@@ -125,10 +131,7 @@ export const SubagentModelOverrideSettings: FC<
|
||||
/>
|
||||
{isUnavailableSavedModel && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
The saved model is no longer enabled and will be ignored until you
|
||||
choose a new override.
|
||||
</AlertDescription>
|
||||
<AlertDescription>{unavailableModelWarning}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
{isMalformedOverride && (
|
||||
|
||||
Reference in New Issue
Block a user