feat: add general subagent model override (#24610)

Adds a deployment-wide admin override for general delegated subagents.

## What changed
- store the general override in `site_configs` and expose it through the
shared `agent-model-override/{context}` API
- apply the general override when spawning delegated general subagents,
while preserving the existing Explore override behavior
- reuse a shared Agents settings form for the general and Explore
override sections

## Validation
- `make gen`
- `go test ./coderd -run 'TestChatModelOverrides'`
- `go test ./coderd/x/chatd -run
'TestSpawnAgent_(GeneralUsesConfiguredModelOverride|GeneralOverrideLogsAndFallsBackWhenCredentialsUnavailable|GeneralOverrideLogsAndFallsBackWhenProviderDisabled)'`
- `pnpm -C site lint:types`
- `pnpm -C site test:storybook --
AgentSettingsAgentsPageView.stories.tsx`
- `make lint`
- `make pre-commit`

> Mux is acting on Mike's behalf.
This commit is contained in:
Michael Suchacz
2026-04-24 12:37:20 +02:00
committed by GitHub
parent 4505278a9f
commit 3d90546aae
23 changed files with 1782 additions and 679 deletions
+2 -2
View File
@@ -1184,8 +1184,8 @@ func New(options *Options) *API {
r.Put("/system-prompt", api.putChatSystemPrompt)
r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions)
r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions)
r.Get("/explore-model-override", api.getChatExploreModelOverride)
r.Put("/explore-model-override", api.putChatExploreModelOverride)
r.Get("/agent-model-override/{context}", api.getChatAgentModelOverride)
r.Put("/agent-model-override/{context}", api.putChatAgentModelOverride)
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
r.Get("/debug-logging", api.getChatDebugLogging)
+14
View File
@@ -2737,6 +2737,13 @@ func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]dat
return files, nil
}
func (q *querier) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return "", err
}
return q.db.GetChatGeneralModelOverride(ctx)
}
func (q *querier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
// The include-default-system-prompt flag is a deployment-wide setting read
// during chat creation by every authenticated user, so no RBAC policy
@@ -7390,6 +7397,13 @@ func (q *querier) UpsertChatExploreModelOverride(ctx context.Context, value stri
return q.db.UpsertChatExploreModelOverride(ctx, value)
}
func (q *querier) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatGeneralModelOverride(ctx, value)
}
func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
+8
View File
@@ -890,6 +890,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatGeneralModelOverride(gomock.Any()).Return("", nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetChatExploreModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
@@ -1205,6 +1209,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatGeneralModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatExploreModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
+16
View File
@@ -1256,6 +1256,14 @@ func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUI
return r0, r1
}
func (m queryMetricsStore) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetChatGeneralModelOverride(ctx)
m.queryLatencies.WithLabelValues("GetChatGeneralModelOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatGeneralModelOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatIncludeDefaultSystemPrompt(ctx)
@@ -5288,6 +5296,14 @@ func (m queryMetricsStore) UpsertChatExploreModelOverride(ctx context.Context, v
return r0
}
func (m queryMetricsStore) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertChatGeneralModelOverride(ctx, value)
m.queryLatencies.WithLabelValues("UpsertChatGeneralModelOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatGeneralModelOverride").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
start := time.Now()
r0 := m.s.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
+29
View File
@@ -2311,6 +2311,21 @@ func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids)
}
// GetChatGeneralModelOverride mocks base method.
func (m *MockStore) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatGeneralModelOverride", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatGeneralModelOverride indicates an expected call of GetChatGeneralModelOverride.
func (mr *MockStoreMockRecorder) GetChatGeneralModelOverride(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatGeneralModelOverride", reflect.TypeOf((*MockStore)(nil).GetChatGeneralModelOverride), ctx)
}
// GetChatIncludeDefaultSystemPrompt mocks base method.
func (m *MockStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
@@ -9933,6 +9948,20 @@ func (mr *MockStoreMockRecorder) UpsertChatExploreModelOverride(ctx, value any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatExploreModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatExploreModelOverride), ctx, value)
}
// UpsertChatGeneralModelOverride mocks base method.
func (m *MockStore) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatGeneralModelOverride", ctx, value)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatGeneralModelOverride indicates an expected call of UpsertChatGeneralModelOverride.
func (mr *MockStoreMockRecorder) UpsertChatGeneralModelOverride(ctx, value any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatGeneralModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatGeneralModelOverride), ctx, value)
}
// UpsertChatIncludeDefaultSystemPrompt mocks base method.
func (m *MockStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
m.ctrl.T.Helper()
+2
View File
@@ -308,6 +308,7 @@ type sqlcQuerier interface {
// loading file content.
GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error)
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
GetChatGeneralModelOverride(ctx context.Context) (string, error)
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
// for deployments created before the explicit include-default toggle.
// When the toggle is unset, a non-empty custom prompt implies false;
@@ -1174,6 +1175,7 @@ type sqlcQuerier interface {
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertChatExploreModelOverride(ctx context.Context, value string) error
UpsertChatGeneralModelOverride(ctx context.Context, value string) error
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
UpsertChatPlanModeInstructions(ctx context.Context, value string) error
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
+22
View File
@@ -20406,6 +20406,18 @@ func (q *sqlQuerier) GetChatExploreModelOverride(ctx context.Context) (string, e
return model_config_id, err
}
const getChatGeneralModelOverride = `-- name: GetChatGeneralModelOverride :one
SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_general_model_override'), '') :: text AS model_config_id
`
func (q *sqlQuerier) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
row := q.db.QueryRowContext(ctx, getChatGeneralModelOverride)
var model_config_id string
err := row.Scan(&model_config_id)
return model_config_id, err
}
const getChatIncludeDefaultSystemPrompt = `-- name: GetChatIncludeDefaultSystemPrompt :one
SELECT
COALESCE(
@@ -20773,6 +20785,16 @@ func (q *sqlQuerier) UpsertChatExploreModelOverride(ctx context.Context, value s
return err
}
const upsertChatGeneralModelOverride = `-- name: UpsertChatGeneralModelOverride :exec
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override'
`
func (q *sqlQuerier) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
_, err := q.db.ExecContext(ctx, upsertChatGeneralModelOverride, value)
return err
}
const upsertChatIncludeDefaultSystemPrompt = `-- name: UpsertChatIncludeDefaultSystemPrompt :exec
INSERT INTO site_configs (key, value)
VALUES (
+8
View File
@@ -175,6 +175,14 @@ SELECT
INSERT INTO site_configs (key, value) VALUES ('agents_chat_explore_model_override', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_explore_model_override';
-- name: GetChatGeneralModelOverride :one
SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_general_model_override'), '') :: text AS model_config_id;
-- name: UpsertChatGeneralModelOverride :exec
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override';
-- name: GetChatDesktopEnabled :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
+145 -24
View File
@@ -449,7 +449,7 @@ func validateChatPlanMode(mode codersdk.ChatPlanMode) bool {
}
}
func parseChatExploreModelOverride(raw string) (*uuid.UUID, error) {
func parseChatModelOverride(raw string) (*uuid.UUID, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
//nolint:nilnil // Empty site-config value means the override is unset.
@@ -457,19 +457,29 @@ func parseChatExploreModelOverride(raw string) (*uuid.UUID, error) {
}
modelConfigID, err := uuid.Parse(trimmed)
if err != nil {
return nil, xerrors.Errorf("parse explore model override: %w", err)
return nil, xerrors.Errorf("parse chat model override: %w", err)
}
return &modelConfigID, nil
}
func formatChatExploreModelOverride(id *uuid.UUID) string {
func formatChatModelOverride(id *uuid.UUID) string {
if id == nil {
return ""
}
return id.String()
}
func validateChatExploreModelOverrideID(
func lookupEnabledChatModelConfigByID(
ctx context.Context,
db database.Store,
id uuid.UUID,
) (database.ChatModelConfig, error) {
//nolint:gocritic // Validation lookup uses AsChatd to check model
// availability independently of the caller's read permissions.
return db.GetEnabledChatModelConfigByID(dbauthz.AsChatd(ctx), id)
}
func validateChatModelOverrideID(
ctx context.Context,
db database.Store,
id *uuid.UUID,
@@ -482,9 +492,7 @@ func validateChatExploreModelOverrideID(
Message: "Invalid model_config_id.",
}
}
//nolint:gocritic // Validation lookup uses system context to check model
// availability independently of the caller's read permissions.
_, err := db.GetEnabledChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), *id)
_, err := lookupEnabledChatModelConfigByID(ctx, db, *id)
if err == nil {
return 0, nil
}
@@ -499,18 +507,23 @@ func validateChatExploreModelOverrideID(
}
}
func (api *API) getChatExploreModelOverrideConfig(
func (api *API) getChatModelOverrideConfig(
ctx context.Context,
settingName string,
getter func(context.Context) (string, error),
) (*uuid.UUID, bool, error) {
raw, err := api.Database.GetChatExploreModelOverride(ctx)
raw, err := getter(ctx)
if err != nil {
return nil, false, xerrors.Errorf("get explore model override: %w", err)
return nil, false, xerrors.Errorf("get %s model override: %w", settingName, err)
}
id, err := parseChatExploreModelOverride(raw)
id, err := parseChatModelOverride(raw)
if err != nil {
// Degrade malformed values to unset so the admin settings page
// remains accessible and the bad value can be cleared.
api.Logger.Warn(ctx, "malformed explore model override in site config, treating as unset",
api.Logger.Warn(
ctx,
"malformed model override in site config, treating as unset",
slog.F("setting", settingName),
slog.F("raw_value", raw),
slog.Error(err),
)
@@ -519,6 +532,64 @@ func (api *API) getChatExploreModelOverrideConfig(
return id, false, nil
}
func parseChatAgentModelOverrideContext(raw string) (codersdk.ChatAgentModelOverrideContext, error) {
overrideContext := codersdk.ChatAgentModelOverrideContext(raw)
if overrideContext.Valid() {
return overrideContext, nil
}
return "", xerrors.Errorf("unknown chat agent model override context %q", raw)
}
type chatAgentModelOverrideSiteConfig struct {
getter func(context.Context) (string, error)
upsert func(context.Context, string) error
}
func (api *API) chatAgentModelOverrideSiteConfig(
overrideContext codersdk.ChatAgentModelOverrideContext,
) (chatAgentModelOverrideSiteConfig, error) {
switch overrideContext {
case codersdk.ChatAgentModelOverrideContextGeneral:
return chatAgentModelOverrideSiteConfig{
getter: api.Database.GetChatGeneralModelOverride,
upsert: api.Database.UpsertChatGeneralModelOverride,
}, nil
case codersdk.ChatAgentModelOverrideContextExplore:
return chatAgentModelOverrideSiteConfig{
getter: api.Database.GetChatExploreModelOverride,
upsert: api.Database.UpsertChatExploreModelOverride,
}, nil
default:
return chatAgentModelOverrideSiteConfig{}, xerrors.Errorf(
"unknown chat agent model override context %q",
overrideContext,
)
}
}
func (api *API) getChatAgentModelOverrideConfig(
ctx context.Context,
overrideContext codersdk.ChatAgentModelOverrideContext,
) (*uuid.UUID, bool, error) {
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
if err != nil {
return nil, false, err
}
return api.getChatModelOverrideConfig(ctx, string(overrideContext), siteConfig.getter)
}
func (api *API) upsertChatAgentModelOverrideConfig(
ctx context.Context,
overrideContext codersdk.ChatAgentModelOverrideContext,
modelConfigID *uuid.UUID,
) error {
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
if err != nil {
return err
}
return siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -3827,53 +3898,103 @@ func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Requ
rw.WriteHeader(http.StatusNoContent)
}
func readChatAgentModelOverrideContext(
rw http.ResponseWriter,
r *http.Request,
) (codersdk.ChatAgentModelOverrideContext, bool) {
ctx := r.Context()
rawContext := chi.URLParam(r, "context")
overrideContext, err := parseChatAgentModelOverrideContext(rawContext)
if err == nil {
return overrideContext, true
}
validContextValues := make(
[]string,
0,
len(codersdk.AllChatAgentModelOverrideContexts()),
)
for _, overrideContext := range codersdk.AllChatAgentModelOverrideContexts() {
validContextValues = append(validContextValues, string(overrideContext))
}
validContexts := strings.Join(validContextValues, ", ")
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid chat agent model override context.",
Detail: fmt.Sprintf(
"Expected one of %s. Got %q.",
validContexts,
rawContext,
),
})
return "", false
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getChatExploreModelOverride(rw http.ResponseWriter, r *http.Request) {
func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
httpapi.ResourceNotFound(rw)
return
}
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
if !ok {
return
}
modelConfigID, hasMalformedOverride, err := api.getChatExploreModelOverrideConfig(ctx)
modelConfigID, isMalformed, err := api.getChatAgentModelOverrideConfig(ctx, overrideContext)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching Explore model override.",
Message: fmt.Sprintf("Internal error fetching %s model override.", overrideContext),
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatExploreModelOverrideResponse{
ModelConfigID: modelConfigID,
HasMalformedOverride: hasMalformedOverride,
})
resp := codersdk.ChatAgentModelOverrideResponse{
Context: overrideContext,
ModelConfigID: formatChatModelOverride(modelConfigID),
IsMalformed: isMalformed,
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) putChatExploreModelOverride(rw http.ResponseWriter, r *http.Request) {
func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
if !ok {
return
}
var req codersdk.UpdateChatExploreModelOverrideRequest
var req codersdk.UpdateChatAgentModelOverrideRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
status, resp := validateChatExploreModelOverrideID(ctx, api.Database, req.ModelConfigID)
modelConfigID, err := parseChatModelOverride(req.ModelConfigID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid model_config_id.",
Detail: fmt.Sprintf("Value %q is not a valid UUID.", req.ModelConfigID),
})
return
}
status, resp := validateChatModelOverrideID(ctx, api.Database, modelConfigID)
if resp != nil {
httpapi.Write(ctx, rw, status, *resp)
return
}
if err := api.Database.UpsertChatExploreModelOverride(ctx, formatChatExploreModelOverride(req.ModelConfigID)); err != nil {
if err := api.upsertChatAgentModelOverrideConfig(ctx, overrideContext, modelConfigID); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error updating Explore model override.",
Message: fmt.Sprintf("Internal error updating %s model override.", overrideContext),
Detail: err.Error(),
})
return
+310 -132
View File
@@ -1134,43 +1134,63 @@ func TestListChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
client, _ := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
_ = createChatModelConfig(t, client)
// Insert the chat that will later be pinned directly
// into the database with a completed status so we
// avoid the background chatd processor entirely.
pinnedDBChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: firstUser.OrganizationID,
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "pinned-chat",
Status: database.ChatStatusCompleted,
ClientType: database.ChatClientTypeUi,
// Create the chat that will later be pinned. It gets the
// earliest updated_at because it is inserted first.
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "pinned-chat",
}},
})
require.NoError(t, err)
// Fill page 1 with newer chats so the pinned chat
// would normally be pushed off the first page
// (default limit 50). Insert directly into the
// database to avoid spawning 51 background chat
// processors, which causes timeouts under -race.
// Fill page 1 with newer chats so the pinned chat would
// normally be pushed off the first page (default limit 50).
const fillerCount = 51
fillerChats := make([]codersdk.Chat, 0, fillerCount)
for i := range fillerCount {
_, insertErr := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: firstUser.OrganizationID,
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: fmt.Sprintf("filler-%d", i),
Status: database.ChatStatusCompleted,
ClientType: database.ChatClientTypeUi,
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: fmt.Sprintf("filler-%d", i),
}},
})
require.NoError(t, insertErr)
require.NoError(t, createErr)
fillerChats = append(fillerChats, c)
}
// Wait for all chats to reach a terminal status so
// updated_at is stable before paginating. A single
// polling loop checks every chat per tick to avoid
// O(N) separate Eventually loops.
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
pending := make(map[uuid.UUID]struct{}, len(allCreated))
for _, c := range allCreated {
pending[c.ID] = struct{}{}
}
testutil.Eventually(ctx, t, func(_ context.Context) bool {
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
})
if listErr != nil {
return false
}
for _, ch := range all {
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
delete(pending, ch.ID)
}
}
return len(pending) == 0
}, testutil.IntervalFast)
// Pin the earliest chat.
err = client.UpdateChat(ctx, pinnedDBChat.ID, codersdk.UpdateChatRequest{
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
@@ -1186,11 +1206,11 @@ func TestListChats(t *testing.T) {
for _, c := range page1 {
page1IDs[c.ID] = struct{}{}
}
_, found := page1IDs[pinnedDBChat.ID]
_, found := page1IDs[pinnedChat.ID]
require.True(t, found, "pinned chat should appear on page 1")
// The pinned chat should be the first item in the list.
require.Equal(t, pinnedDBChat.ID, page1[0].ID, "pinned chat should be first")
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
})
// Test cursor pagination with a mix of pinned and unpinned chats.
@@ -9396,6 +9416,44 @@ func createChatModelConfig(t *testing.T, client *codersdk.ExperimentalClient) co
return modelConfig
}
func createAdditionalChatModelConfig(
t *testing.T,
client *codersdk.ExperimentalClient,
provider string,
model string,
) codersdk.ChatModelConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
contextLimit := int64(4096)
isDefault := false
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider,
Model: model,
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
return modelConfig
}
func createDisabledChatModelConfig(
t *testing.T,
client *codersdk.ExperimentalClient,
provider string,
model string,
) codersdk.ChatModelConfig {
t.Helper()
modelConfig := createAdditionalChatModelConfig(t, client, provider, model)
ctx := testutil.Context(t, testutil.WaitLong)
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
Enabled: ptr.Ref(false),
})
require.NoError(t, err)
return updated
}
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
func TestChatSystemPrompt(t *testing.T) {
t.Parallel()
@@ -9897,129 +9955,249 @@ func TestChatPlanModeInstructions(t *testing.T) {
})
}
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
func TestChatExploreModelOverride(t *testing.T) {
//nolint:tparallel,paralleltest // Setting subtests share per-setting coderdtest instances.
func TestChatModelOverrides(t *testing.T) {
t.Parallel()
adminClient, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
defaultModel := createChatModelConfig(t, adminClient)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
createAdditionalModel := func(t *testing.T, model string, enabled bool) codersdk.ChatModelConfig {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
contextLimit := int64(4096)
isDefault := false
modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: defaultModel.Provider,
Model: model,
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
if enabled {
return modelConfig
}
updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
Enabled: ptr.Ref(false),
})
require.NoError(t, err)
return updated
type overrideResponse struct {
context codersdk.ChatAgentModelOverrideContext
modelConfigID string
isMalformed bool
}
t.Run("DefaultGETReturnsEmpty", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
type settingTest struct {
name string
context codersdk.ChatAgentModelOverrideContext
dbGet func(context.Context, database.Store) (string, error)
dbUpsert func(context.Context, database.Store, string) error
}
resp, err := adminClient.GetChatExploreModelOverride(ctx)
require.NoError(t, err)
require.Nil(t, resp.ModelConfigID)
require.False(t, resp.HasMalformedOverride)
})
settingPath := func(overrideContext codersdk.ChatAgentModelOverrideContext) string {
return "/api/experimental/chats/config/agent-model-override/" + string(overrideContext)
}
t.Run("AdminCanSetAndClear", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
overrideModel := createAdditionalModel(t, "gpt-4.1-mini", true)
getOverride := func(
ctx context.Context,
client *codersdk.ExperimentalClient,
overrideContext codersdk.ChatAgentModelOverrideContext,
) (overrideResponse, error) {
resp, err := client.GetChatAgentModelOverride(ctx, overrideContext)
if err != nil {
return overrideResponse{}, err
}
return overrideResponse{
context: resp.Context,
modelConfigID: resp.ModelConfigID,
isMalformed: resp.IsMalformed,
}, nil
}
err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
ModelConfigID: &overrideModel.ID,
putOverride := func(
ctx context.Context,
client *codersdk.ExperimentalClient,
overrideContext codersdk.ChatAgentModelOverrideContext,
modelConfigID string,
) error {
return client.UpdateChatAgentModelOverride(
ctx,
overrideContext,
codersdk.UpdateChatAgentModelOverrideRequest{ModelConfigID: modelConfigID},
)
}
settings := []settingTest{
{
name: "General",
context: codersdk.ChatAgentModelOverrideContextGeneral,
dbGet: func(ctx context.Context, db database.Store) (string, error) {
return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx))
},
dbUpsert: func(ctx context.Context, db database.Store, value string) error {
return db.UpsertChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx), value)
},
},
{
name: "Explore",
context: codersdk.ChatAgentModelOverrideContextExplore,
dbGet: func(ctx context.Context, db database.Store) (string, error) {
return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx))
},
dbUpsert: func(ctx context.Context, db database.Store, value string) error {
return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value)
},
},
}
for _, setting := range settings {
t.Run(setting.name, func(t *testing.T) {
adminClient, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
defaultModel := createChatModelConfig(t, adminClient)
openAIModel := createAdditionalChatModelConfig(
t,
adminClient,
defaultModel.Provider,
"gpt-4.1-mini-"+string(setting.context),
)
disabledModel := createDisabledChatModelConfig(
t,
adminClient,
defaultModel.Provider,
"gpt-4.1-disabled-"+string(setting.context),
)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
t.Run("DefaultGETReturnsEmpty", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
resp, err := getOverride(ctx, adminClient, setting.context)
require.NoError(t, err)
require.Equal(t, setting.context, resp.context)
require.Empty(t, resp.modelConfigID)
require.False(t, resp.isMalformed)
raw, err := setting.dbGet(ctx, db)
require.NoError(t, err)
require.Empty(t, raw, "expected empty stored override for %s", settingPath(setting.context))
})
t.Run("AdminCanSetAndClear", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := putOverride(ctx, adminClient, setting.context, openAIModel.ID.String())
require.NoError(t, err)
raw, err := setting.dbGet(ctx, db)
require.NoError(t, err)
require.Equal(t, openAIModel.ID.String(), raw, "expected stored override for %s", settingPath(setting.context))
resp, err := getOverride(ctx, adminClient, setting.context)
require.NoError(t, err)
require.Equal(t, setting.context, resp.context)
require.Equal(t, openAIModel.ID.String(), resp.modelConfigID)
require.False(t, resp.isMalformed)
err = putOverride(ctx, adminClient, setting.context, "")
require.NoError(t, err)
raw, err = setting.dbGet(ctx, db)
require.NoError(t, err)
require.Empty(t, raw, "expected cleared override for %s", settingPath(setting.context))
resp, err = getOverride(ctx, adminClient, setting.context)
require.NoError(t, err)
require.Equal(t, setting.context, resp.context)
require.Empty(t, resp.modelConfigID)
require.False(t, resp.isMalformed)
})
t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
require.NoError(t, setting.dbUpsert(ctx, db, "not-a-uuid"))
resp, err := getOverride(ctx, adminClient, setting.context)
require.NoError(t, err)
require.Equal(t, setting.context, resp.context)
require.Empty(t, resp.modelConfigID)
require.True(t, resp.isMalformed)
err = putOverride(ctx, adminClient, setting.context, "")
require.NoError(t, err)
raw, err := setting.dbGet(ctx, db)
require.NoError(t, err)
require.Empty(t, raw, "expected malformed override to be cleared for %s", settingPath(setting.context))
resp, err = getOverride(ctx, adminClient, setting.context)
require.NoError(t, err)
require.Equal(t, setting.context, resp.context)
require.Empty(t, resp.modelConfigID)
require.False(t, resp.isMalformed)
})
t.Run("InvalidUUIDReturns400", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := putOverride(ctx, adminClient, setting.context, "not-a-uuid")
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
require.Equal(t, "Value \"not-a-uuid\" is not a valid UUID.", sdkErr.Detail)
})
t.Run("DisabledModelReturns400", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := putOverride(ctx, adminClient, setting.context, disabledModel.ID.String())
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
})
t.Run("UnknownModelReturns400", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
unknownModelID := uuid.New()
err := putOverride(ctx, adminClient, setting.context, unknownModelID.String())
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
})
t.Run("NonAdminGETReturns404", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
_, err := getOverride(ctx, memberClient, setting.context)
requireSDKError(t, err, http.StatusNotFound)
})
t.Run("NonAdminPUTReturns403", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := putOverride(ctx, memberClient, setting.context, defaultModel.ID.String())
requireSDKError(t, err, http.StatusForbidden)
})
})
require.NoError(t, err)
}
resp, err := adminClient.GetChatExploreModelOverride(ctx)
require.NoError(t, err)
require.NotNil(t, resp.ModelConfigID)
require.Equal(t, overrideModel.ID, *resp.ModelConfigID)
require.False(t, resp.HasMalformedOverride)
err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{})
require.NoError(t, err)
resp, err = adminClient.GetChatExploreModelOverride(ctx)
require.NoError(t, err)
require.Nil(t, resp.ModelConfigID)
require.False(t, resp.HasMalformedOverride)
})
t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) {
t.Run("UnknownContextReturns400", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
require.NoError(t, db.UpsertChatExploreModelOverride(
dbauthz.AsSystemRestricted(ctx),
"not-a-uuid",
))
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient.Client)
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
resp, err := adminClient.GetChatExploreModelOverride(ctx)
require.NoError(t, err)
require.Nil(t, resp.ModelConfigID)
require.True(t, resp.HasMalformedOverride)
err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{})
require.NoError(t, err)
resp, err = adminClient.GetChatExploreModelOverride(ctx)
require.NoError(t, err)
require.Nil(t, resp.ModelConfigID)
require.False(t, resp.HasMalformedOverride)
})
t.Run("DisabledModelReturns400", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
disabledModel := createAdditionalModel(t, "gpt-4.1-disabled", false)
err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
ModelConfigID: &disabledModel.ID,
})
_, err := getOverride(ctx, adminClient, unknownContext)
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
require.Equal(
t,
`Expected one of general, explore. Got "not-a-context".`,
sdkErr.Detail,
)
err = putOverride(ctx, adminClient, unknownContext, "")
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
require.Equal(
t,
`Expected one of general, explore. Got "not-a-context".`,
sdkErr.Detail,
)
})
t.Run("UnknownModelReturns400", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
unknownModelID := uuid.New()
err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
ModelConfigID: &unknownModelID,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
})
t.Run("NonAdminGETReturns404", func(t *testing.T) {
t.Run("NonAdminUnknownContextUsesAuthResponse", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
_, err := memberClient.GetChatExploreModelOverride(ctx)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
_, err := getOverride(ctx, memberClient, unknownContext)
requireSDKError(t, err, http.StatusNotFound)
})
t.Run("NonAdminPUTReturns403", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := memberClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
ModelConfigID: &defaultModel.ID,
})
err = putOverride(ctx, memberClient, unknownContext, "")
requireSDKError(t, err, http.StatusForbidden)
})
}
+187 -42
View File
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
"sort"
@@ -27,6 +28,18 @@ import (
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
var errInvalidModelOverrideMetadata = xerrors.New("invalid model override metadata")
type modelOverrideConfigResolver func(
context.Context,
uuid.UUID,
) (database.ChatModelConfig, string, error)
type modelOverrideProviderKeysResolver func(
context.Context,
uuid.UUID,
) (chatprovider.ProviderAPIKeys, error)
const (
subagentAwaitPollInterval = 200 * time.Millisecond
subagentAwaitFallbackPoll = 5 * time.Second
@@ -90,66 +103,199 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool {
return enabled
}
func (p *Server) resolveExploreSubagentModelConfigID(
func subagentModelOverrideLogLabel(
overrideContext codersdk.ChatAgentModelOverrideContext,
) string {
switch overrideContext {
case codersdk.ChatAgentModelOverrideContextGeneral:
return "general delegated child"
case codersdk.ChatAgentModelOverrideContextExplore:
return "explore"
default:
return string(overrideContext)
}
}
func readSubagentModelOverride(
ctx context.Context,
ownerID uuid.UUID,
fallback uuid.UUID,
) (uuid.UUID, error) {
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
chatdCtx := dbauthz.AsChatd(ctx)
raw, err := p.db.GetChatExploreModelOverride(chatdCtx)
if err != nil {
return uuid.Nil, xerrors.Errorf("get Explore model override: %w", err)
}
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return fallback, nil
}
configuredModelConfigID, err := uuid.Parse(trimmed)
if err != nil {
p.logger.Warn(ctx,
"invalid Explore model override, falling back to current turn model",
slog.F("raw_model_config_id", trimmed),
slog.Error(err),
db database.Store,
overrideContext codersdk.ChatAgentModelOverrideContext,
) (string, error) {
switch overrideContext {
case codersdk.ChatAgentModelOverrideContextGeneral:
return db.GetChatGeneralModelOverride(ctx)
case codersdk.ChatAgentModelOverrideContextExplore:
return db.GetChatExploreModelOverride(ctx)
default:
return "", xerrors.Errorf(
"unknown subagent model override context %q",
overrideContext,
)
return fallback, nil
}
modelConfig, err := p.db.GetEnabledChatModelConfigByID(
chatdCtx,
configuredModelConfigID,
)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
p.logger.Warn(ctx,
"explore model override is unavailable, falling back to current turn model",
slog.F("model_config_id", configuredModelConfigID),
)
return fallback, nil
}
return uuid.Nil, xerrors.Errorf("get enabled chat model config by id: %w", err)
}
func validateModelConfigAndResolveProvider(
modelConfig database.ChatModelConfig,
) (database.ChatModelConfig, string, error) {
if !modelConfig.Enabled {
return database.ChatModelConfig{}, "", sql.ErrNoRows
}
providerName, _, err := chatprovider.ResolveModelWithProviderHint(
modelConfig.Model,
modelConfig.Provider,
)
if err != nil {
return uuid.Nil, xerrors.Errorf("resolve Explore model provider: %w", err)
return database.ChatModelConfig{}, "", xerrors.Errorf(
"%w: %v",
errInvalidModelOverrideMetadata,
err,
)
}
providerKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID)
return modelConfig, providerName, nil
}
func enabledProviderContainsName(
providers []database.ChatProvider,
providerName string,
) bool {
normalizedProviderName := chatprovider.NormalizeProvider(providerName)
for _, provider := range providers {
if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName {
return true
}
}
return false
}
func (p *Server) resolveConfiguredModelOverride(
ctx context.Context,
overrideContext string,
raw string,
ownerID uuid.UUID,
resolveModelConfig modelOverrideConfigResolver,
resolveProviderKeys modelOverrideProviderKeysResolver,
) (database.ChatModelConfig, bool, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return database.ChatModelConfig{}, false, nil
}
configuredModelConfigID, err := uuid.Parse(trimmed)
if err != nil {
return uuid.Nil, xerrors.Errorf("resolve provider API keys: %w", err)
p.logger.Info(ctx,
"invalid model override, ignoring",
slog.F("override_context", overrideContext),
slog.F("raw_model_config_id", trimmed),
slog.Error(err),
)
return database.ChatModelConfig{}, false, nil
}
if providerKeys.APIKey(providerName) == "" {
p.logger.Warn(ctx,
"explore model override credentials are unavailable, falling back to current turn model",
modelConfig, providerName, err := resolveModelConfig(
ctx,
configuredModelConfigID,
)
if err != nil {
switch {
case xerrors.Is(err, sql.ErrNoRows):
p.logger.Info(ctx,
"model override is unavailable, ignoring",
slog.F("override_context", overrideContext),
slog.F("model_config_id", configuredModelConfigID),
)
case errors.Is(err, errInvalidModelOverrideMetadata):
p.logger.Info(ctx,
"model override metadata is invalid, ignoring",
slog.F("override_context", overrideContext),
slog.F("model_config_id", configuredModelConfigID),
slog.Error(err),
)
default:
p.logger.Warn(ctx,
"failed to resolve model override, ignoring",
slog.F("override_context", overrideContext),
slog.F("model_config_id", configuredModelConfigID),
slog.Error(err),
)
}
return database.ChatModelConfig{}, false, nil
}
providerKeys, err := resolveProviderKeys(ctx, ownerID)
if err != nil {
return database.ChatModelConfig{}, false, xerrors.Errorf(
"resolve provider API keys: %w",
err,
)
}
if providerKeys.APIKey(providerName) == "" &&
!(chatprovider.ProviderAllowsAmbientCredentials(providerName) &&
providerKeys.HasProvider(providerName)) {
p.logger.Info(ctx,
"model override credentials are unavailable, ignoring",
slog.F("override_context", overrideContext),
slog.F("model_config_id", configuredModelConfigID),
slog.F("provider", providerName),
)
return fallback, nil
return database.ChatModelConfig{}, false, nil
}
return modelConfig, true, nil
}
func (p *Server) resolveSubagentModelConfigID(
ctx context.Context,
ownerID uuid.UUID,
overrideContext codersdk.ChatAgentModelOverrideContext,
) (uuid.UUID, error) {
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
chatdCtx := dbauthz.AsChatd(ctx)
raw, err := readSubagentModelOverride(chatdCtx, p.db, overrideContext)
if err != nil {
return uuid.Nil, xerrors.Errorf(
"get %s model override: %w",
subagentModelOverrideLogLabel(overrideContext),
err,
)
}
modelConfig, ok, err := p.resolveConfiguredModelOverride(
ctx,
string(overrideContext),
raw,
ownerID,
p.resolveModelConfigAndNormalizedProvider,
p.resolveUserProviderAPIKeys,
)
if err != nil {
return uuid.Nil, err
}
if !ok {
return uuid.Nil, nil
}
return modelConfig.ID, nil
}
func (p *Server) resolveModelConfigAndNormalizedProvider(
ctx context.Context,
modelConfigID uuid.UUID,
) (database.ChatModelConfig, string, error) {
if modelConfigID == uuid.Nil {
return database.ChatModelConfig{}, "", sql.ErrNoRows
}
modelConfig, err := p.configCache.ModelConfigByID(ctx, modelConfigID)
if err != nil {
return database.ChatModelConfig{}, "", err
}
modelConfig, providerName, err := validateModelConfigAndResolveProvider(modelConfig)
if err != nil {
return database.ChatModelConfig{}, "", err
}
enabledProviders, err := p.configCache.EnabledProviders(ctx)
if err != nil {
return database.ChatModelConfig{}, "", err
}
if !enabledProviderContainsName(enabledProviders, providerName) {
return database.ChatModelConfig{}, "", sql.ErrNoRows
}
return modelConfig, providerName, nil
}
func (p *Server) subagentTools(
ctx context.Context,
currentChat func() database.Chat,
@@ -444,7 +590,6 @@ func (p *Server) loadSubagentSpawnParentChat(
if err := validateSubagentSpawnParent(parent); err != nil {
return database.Chat{}, err
}
reloadedParent, err := p.db.GetChatByID(ctx, parent.ID)
if err != nil {
p.logger.Warn(ctx, "failed to load parent chat for spawn_agent",
+20 -4
View File
@@ -9,6 +9,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
const (
@@ -43,22 +44,37 @@ func allSubagentDefinitions() []subagentDefinition {
{
id: subagentTypeGeneral,
description: "delegated work that may inspect or modify workspace files",
buildOptions: func(_ context.Context, _ *Server, _ database.Chat, _ database.Chat, _ uuid.UUID, _ string) (childSubagentChatOptions, error) {
return childSubagentChatOptions{}, nil
buildOptions: func(ctx context.Context, p *Server, parent database.Chat, _ database.Chat, _ uuid.UUID, _ string) (childSubagentChatOptions, error) {
modelConfigID, err := p.resolveSubagentModelConfigID(
ctx,
parent.OwnerID,
codersdk.ChatAgentModelOverrideContextGeneral,
)
if err != nil {
return childSubagentChatOptions{}, err
}
options := childSubagentChatOptions{}
if modelConfigID != uuid.Nil {
options.modelConfigIDOverride = &modelConfigID
}
return options, nil
},
},
{
id: subagentTypeExplore,
description: "read-only discovery, code tracing, and system understanding",
buildOptions: func(ctx context.Context, p *Server, _ database.Chat, turnParent database.Chat, currentModelConfigID uuid.UUID, _ string) (childSubagentChatOptions, error) {
modelConfigID, err := p.resolveExploreSubagentModelConfigID(
modelConfigID, err := p.resolveSubagentModelConfigID(
ctx,
turnParent.OwnerID,
currentModelConfigID,
codersdk.ChatAgentModelOverrideContextExplore,
)
if err != nil {
return childSubagentChatOptions{}, err
}
if modelConfigID == uuid.Nil {
modelConfigID = currentModelConfigID
}
inheritedMCPServerIDs, err := p.resolveExploreToolSnapshot(
ctx,
turnParent,
+325 -5
View File
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"sync"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -71,7 +73,14 @@ func newInternalTestServer(
ps pubsub.Pubsub,
keys chatprovider.ProviderAPIKeys,
) *Server {
return newInternalTestServerWithClock(t, db, ps, keys, nil)
return newInternalTestServerWithLoggerAndClock(
t,
db,
ps,
keys,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
nil,
)
}
func newInternalTestServerWithClock(
@@ -80,10 +89,37 @@ func newInternalTestServerWithClock(
ps pubsub.Pubsub,
keys chatprovider.ProviderAPIKeys,
clk quartz.Clock,
) *Server {
return newInternalTestServerWithLoggerAndClock(
t,
db,
ps,
keys,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clk,
)
}
func newInternalTestServerWithLogger(
t *testing.T,
db database.Store,
ps pubsub.Pubsub,
keys chatprovider.ProviderAPIKeys,
logger slog.Logger,
) *Server {
return newInternalTestServerWithLoggerAndClock(t, db, ps, keys, logger, nil)
}
func newInternalTestServerWithLoggerAndClock(
t *testing.T,
db database.Store,
ps pubsub.Pubsub,
keys chatprovider.ProviderAPIKeys,
logger slog.Logger,
clk quartz.Clock,
) *Server {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := New(Config{
Logger: logger,
Database: db,
@@ -101,6 +137,35 @@ func newInternalTestServerWithClock(
return server
}
type subagentTestLogSink struct {
mu sync.Mutex
entries []slog.SinkEntry
}
func (s *subagentTestLogSink) LogEntry(_ context.Context, entry slog.SinkEntry) {
s.mu.Lock()
defer s.mu.Unlock()
s.entries = append(s.entries, entry)
}
func (*subagentTestLogSink) Sync() {}
func (s *subagentTestLogSink) entriesAtLevelWithMessage(
level slog.Level,
message string,
) []slog.SinkEntry {
s.mu.Lock()
defer s.mu.Unlock()
entries := make([]slog.SinkEntry, 0, len(s.entries))
for _, entry := range s.entries {
if entry.Level == level && entry.Message == message {
entries = append(entries, entry)
}
}
return entries
}
// seedInternalChatDeps inserts an OpenAI provider and model config
// into the database and returns the created user, organization,
// and model. This deliberately does NOT create an Anthropic
@@ -218,6 +283,54 @@ func insertInternalChatModelConfig(
userID uuid.UUID,
model string,
enabled bool,
) database.ChatModelConfig {
return insertInternalChatModelConfigForProvider(
ctx,
t,
db,
userID,
"openai",
model,
enabled,
)
}
func insertInternalChatProvider(
ctx context.Context,
t *testing.T,
db database.Store,
userID uuid.UUID,
provider string,
apiKey string,
centralAPIKeyEnabled bool,
allowUserAPIKey bool,
allowCentralAPIKeyFallback bool,
) database.ChatProvider {
t.Helper()
providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: provider,
DisplayName: provider,
APIKey: apiKey,
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
Enabled: true,
CentralApiKeyEnabled: centralAPIKeyEnabled,
AllowUserApiKey: allowUserAPIKey,
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
})
require.NoError(t, err)
return providerConfig
}
func insertInternalChatModelConfigForProvider(
ctx context.Context,
t *testing.T,
db database.Store,
userID uuid.UUID,
provider string,
model string,
enabled bool,
) database.ChatModelConfig {
t.Helper()
return insertInternalChatModelConfigWithOptions(
@@ -225,6 +338,7 @@ func insertInternalChatModelConfig(
t,
db,
userID,
provider,
model,
enabled,
json.RawMessage(`{}`),
@@ -236,6 +350,7 @@ func insertInternalChatModelConfigWithOptions(
t *testing.T,
db database.Store,
userID uuid.UUID,
provider string,
model string,
enabled bool,
options json.RawMessage,
@@ -243,7 +358,7 @@ func insertInternalChatModelConfigWithOptions(
t.Helper()
modelConfig, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Provider: provider,
Model: model,
DisplayName: model,
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
@@ -574,6 +689,213 @@ func TestSpawnAgent_GeneralInheritsParentModelWhenOmitted(t *testing.T) {
require.Equal(t, parentChat.LastModelConfigID, childChat.LastModelConfigID)
}
func TestSpawnAgent_GeneralUsesConfiguredModelOverride(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
user, org, model := seedInternalChatDeps(ctx, t, db)
overrideModel := insertInternalChatModelConfig(
ctx, t, db, user.ID, "general-override-"+uuid.NewString(), true,
)
require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String()))
parentChat := createInternalParentChat(
ctx, t, server, db, org.ID, user.ID, model.ID, "parent-general-override",
)
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
Type: subagentTypeGeneral,
Prompt: "delegate general work",
})
childID := requireSpawnAgentChildChatID(t, resp)
childChat, err := db.GetChatByID(ctx, childID)
require.NoError(t, err)
require.Equal(t, overrideModel.ID, childChat.LastModelConfigID)
require.False(t, childChat.PlanMode.Valid)
}
func TestSpawnAgent_GeneralOverrideLogsAndFallsBackWhenCredentialsUnavailable(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
logSink := &subagentTestLogSink{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
server := newInternalTestServerWithLogger(t, db, ps, chatprovider.ProviderAPIKeys{}, logger)
ctx := chatdTestContext(t)
user, org, model := seedInternalChatDeps(ctx, t, db)
insertInternalChatProvider(
ctx,
t,
db,
user.ID,
"openai-compat",
"",
false,
true,
false,
)
overrideModel := insertInternalChatModelConfigForProvider(
ctx,
t,
db,
user.ID,
"openai-compat",
"gpt-4o-mini",
true,
)
require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String()))
parent, err := server.CreateChat(ctx, CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
Title: "parent-general-credentials-fallback",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("delegate work"),
},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
Type: subagentTypeGeneral,
Prompt: "inspect provider credentials",
})
childID := requireSpawnAgentChildChatID(t, resp)
childChat, err := db.GetChatByID(ctx, childID)
require.NoError(t, err)
require.Equal(t, model.ID, childChat.LastModelConfigID)
require.False(t, childChat.PlanMode.Valid)
require.Len(t, logSink.entriesAtLevelWithMessage(
slog.LevelInfo,
"model override credentials are unavailable, ignoring",
), 1)
}
func TestSpawnAgent_GeneralOverrideLogsAndFallsBackWhenProviderDisabled(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
logSink := &subagentTestLogSink{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
server := newInternalTestServerWithLogger(
t,
db,
ps,
chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{
"openai-compat": "fallback-key",
},
},
logger,
)
ctx := chatdTestContext(t)
user, org, model := seedInternalChatDeps(ctx, t, db)
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai-compat",
DisplayName: "openai-compat",
APIKey: "",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: false,
CentralApiKeyEnabled: false,
AllowUserApiKey: true,
AllowCentralApiKeyFallback: false,
})
require.NoError(t, err)
overrideModel := insertInternalChatModelConfigForProvider(
ctx,
t,
db,
user.ID,
"openai-compat",
"gpt-4o-mini",
true,
)
require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String()))
parent, err := server.CreateChat(ctx, CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
Title: "parent-general-disabled-provider-fallback",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("delegate work"),
},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
Type: subagentTypeGeneral,
Prompt: "inspect disabled providers",
})
childID := requireSpawnAgentChildChatID(t, resp)
childChat, err := db.GetChatByID(ctx, childID)
require.NoError(t, err)
require.Equal(t, model.ID, childChat.LastModelConfigID)
require.False(t, childChat.PlanMode.Valid)
require.Len(t, logSink.entriesAtLevelWithMessage(
slog.LevelInfo,
"model override is unavailable, ignoring",
), 1)
}
func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider(
t *testing.T,
) {
t.Parallel()
logSink := &subagentTestLogSink{}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
server := &Server{logger: logger}
ctx := chatdTestContext(t)
ownerID := uuid.New()
modelConfig := database.ChatModelConfig{
ID: uuid.New(),
Provider: "bedrock",
Model: "anthropic.claude-haiku-4-5-20251001-v1:0",
DisplayName: "Ambient Bedrock Override",
Enabled: true,
}
resolvedModelConfig, ok, err := server.resolveConfiguredModelOverride(
ctx,
"plan",
modelConfig.ID.String(),
ownerID,
func(
_ context.Context,
configuredModelConfigID uuid.UUID,
) (database.ChatModelConfig, string, error) {
require.Equal(t, modelConfig.ID, configuredModelConfigID)
return modelConfig, "bedrock", nil
},
func(
_ context.Context,
resolvedOwnerID uuid.UUID,
) (chatprovider.ProviderAPIKeys, error) {
require.Equal(t, ownerID, resolvedOwnerID)
return chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{"bedrock": ""},
}, nil
},
)
require.NoError(t, err)
require.True(t, ok)
require.Equal(t, modelConfig, resolvedModelConfig)
require.Empty(t, logSink.entriesAtLevelWithMessage(
slog.LevelInfo,
"model override credentials are unavailable, ignoring",
))
}
func TestCreateChildSubagentChat_OverrideWorksWhenParentHasNoModel(t *testing.T) {
t.Parallel()
@@ -1328,7 +1650,6 @@ func TestSubagentLifecycleToolsIncludePersistedSubagentTypeAcrossVariants(t *tes
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
@@ -1450,7 +1771,6 @@ func TestSubagentLifecycleToolErrorsIncludePersistedSubagentType(t *testing.T) {
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
+57 -22
View File
@@ -562,19 +562,46 @@ type UpdateChatPlanModeInstructionsRequest struct {
PlanModeInstructions string `json:"plan_mode_instructions"`
}
// ChatExploreModelOverrideResponse is the response body for the Explore
// subagent model override configuration endpoint.
type ChatExploreModelOverrideResponse struct {
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
// HasMalformedOverride reports whether the saved override is malformed and
// is currently being treated as unset.
HasMalformedOverride bool `json:"has_malformed_override"`
// ChatAgentModelOverrideContext identifies which chat or subagent context
// a deployment override applies to.
type ChatAgentModelOverrideContext string
const (
ChatAgentModelOverrideContextGeneral ChatAgentModelOverrideContext = "general"
ChatAgentModelOverrideContextExplore ChatAgentModelOverrideContext = "explore"
)
// Valid reports whether the override context is one of the supported values.
func (c ChatAgentModelOverrideContext) Valid() bool {
switch c {
case ChatAgentModelOverrideContextGeneral,
ChatAgentModelOverrideContextExplore:
return true
default:
return false
}
}
// UpdateChatExploreModelOverrideRequest is the request body for updating the
// Explore subagent model override configuration endpoint.
type UpdateChatExploreModelOverrideRequest struct {
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
// AllChatAgentModelOverrideContexts returns all supported override contexts.
func AllChatAgentModelOverrideContexts() []ChatAgentModelOverrideContext {
return []ChatAgentModelOverrideContext{
ChatAgentModelOverrideContextGeneral,
ChatAgentModelOverrideContextExplore,
}
}
// ChatAgentModelOverrideResponse is the response body for the chat agent
// model override configuration endpoint.
type ChatAgentModelOverrideResponse struct {
Context ChatAgentModelOverrideContext `json:"context"`
ModelConfigID string `json:"model_config_id"`
IsMalformed bool `json:"is_malformed"`
}
// UpdateChatAgentModelOverrideRequest is the request body for updating the
// chat agent model override configuration endpoint.
type UpdateChatAgentModelOverrideRequest struct {
ModelConfigID string `json:"model_config_id"`
}
// UserChatCustomPrompt is the request and response body for the
@@ -2024,25 +2051,33 @@ func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context,
return nil
}
// GetChatExploreModelOverride returns the deployment-wide Explore subagent
// model override.
func (c *ExperimentalClient) GetChatExploreModelOverride(ctx context.Context) (ChatExploreModelOverrideResponse, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/explore-model-override", nil)
// GetChatAgentModelOverride returns the deployment-wide chat agent model
// override for the requested context.
func (c *ExperimentalClient) GetChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext) (ChatAgentModelOverrideResponse, error) {
path := fmt.Sprintf(
"/api/experimental/chats/config/agent-model-override/%s",
url.PathEscape(string(override)),
)
res, err := c.Request(ctx, http.MethodGet, path, nil)
if err != nil {
return ChatExploreModelOverrideResponse{}, err
return ChatAgentModelOverrideResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatExploreModelOverrideResponse{}, ReadBodyAsError(res)
return ChatAgentModelOverrideResponse{}, ReadBodyAsError(res)
}
var resp ChatExploreModelOverrideResponse
var resp ChatAgentModelOverrideResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// UpdateChatExploreModelOverride updates the deployment-wide Explore subagent
// model override.
func (c *ExperimentalClient) UpdateChatExploreModelOverride(ctx context.Context, req UpdateChatExploreModelOverrideRequest) error {
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/explore-model-override", req)
// UpdateChatAgentModelOverride updates the deployment-wide chat agent model
// override for the requested context.
func (c *ExperimentalClient) UpdateChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext, req UpdateChatAgentModelOverrideRequest) error {
path := fmt.Sprintf(
"/api/experimental/chats/config/agent-model-override/%s",
url.PathEscape(string(override)),
)
res, err := c.Request(ctx, http.MethodPut, path, req)
if err != nil {
return err
}
+13 -11
View File
@@ -3257,20 +3257,22 @@ class ExperimentalApiMethods {
);
};
getChatExploreModelOverride =
async (): Promise<TypesGen.ChatExploreModelOverrideResponse> => {
const response =
await this.axios.get<TypesGen.ChatExploreModelOverrideResponse>(
"/api/experimental/chats/config/explore-model-override",
);
return response.data;
};
getChatAgentModelOverride = async (
context: TypesGen.ChatAgentModelOverrideContext,
): Promise<TypesGen.ChatAgentModelOverrideResponse> => {
const response =
await this.axios.get<TypesGen.ChatAgentModelOverrideResponse>(
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
);
return response.data;
};
updateChatExploreModelOverride = async (
req: TypesGen.UpdateChatExploreModelOverrideRequest,
updateChatAgentModelOverride = async (
context: TypesGen.ChatAgentModelOverrideContext,
req: TypesGen.UpdateChatAgentModelOverrideRequest,
): Promise<void> => {
await this.axios.put(
"/api/experimental/chats/config/explore-model-override",
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
req,
);
};
-17
View File
@@ -1127,23 +1127,6 @@ export const updateChatPlanModeInstructions = (queryClient: QueryClient) => ({
},
});
const chatExploreModelOverrideKey = ["chat-explore-model-override"] as const;
export const chatExploreModelOverride = () => ({
queryKey: chatExploreModelOverrideKey,
queryFn: () => API.experimental.getChatExploreModelOverride(),
});
export const updateChatExploreModelOverride = (queryClient: QueryClient) => ({
mutationFn: (req: TypesGen.UpdateChatExploreModelOverrideRequest) =>
API.experimental.updateChatExploreModelOverride(req),
onSuccess: async () => {
await queryClient.invalidateQueries({
queryKey: chatExploreModelOverrideKey,
});
},
});
const chatDesktopEnabledKey = ["chat-desktop-enabled"] as const;
export const chatDesktopEnabled = () => ({
+28 -23
View File
@@ -1272,6 +1272,25 @@ export interface Chat {
readonly children: readonly Chat[];
}
// From codersdk/chats.go
export type ChatAgentModelOverrideContext = "explore" | "general";
export const ChatAgentModelOverrideContexts: ChatAgentModelOverrideContext[] = [
"explore",
"general",
];
// From codersdk/chats.go
/**
* ChatAgentModelOverrideResponse is the response body for the chat agent
* model override configuration endpoint.
*/
export interface ChatAgentModelOverrideResponse {
readonly context: ChatAgentModelOverrideContext;
readonly model_config_id: string;
readonly is_malformed: boolean;
}
// From codersdk/chats.go
export type ChatBusyBehavior = "interrupt" | "queue";
@@ -1590,20 +1609,6 @@ export interface ChatDiffStatus {
readonly stale_at?: string;
}
// From codersdk/chats.go
/**
* ChatExploreModelOverrideResponse is the response body for the Explore
* subagent model override configuration endpoint.
*/
export interface ChatExploreModelOverrideResponse {
readonly model_config_id?: string;
/**
* HasMalformedOverride reports whether the saved override is malformed and
* is currently being treated as unset.
*/
readonly has_malformed_override: boolean;
}
// From codersdk/chats.go
/**
* ChatFileMetadata contains lightweight metadata about a file
@@ -7657,6 +7662,15 @@ export interface UpdateAppearanceConfig {
readonly announcement_banners: readonly BannerConfig[];
}
// From codersdk/chats.go
/**
* UpdateChatAgentModelOverrideRequest is the request body for updating the
* chat agent model override configuration endpoint.
*/
export interface UpdateChatAgentModelOverrideRequest {
readonly model_config_id: string;
}
// From codersdk/chats.go
/**
* UpdateChatDebugLoggingAllowUsersRequest is the admin request to
@@ -7674,15 +7688,6 @@ export interface UpdateChatDesktopEnabledRequest {
readonly enable_desktop: boolean;
}
// From codersdk/chats.go
/**
* UpdateChatExploreModelOverrideRequest is the request body for updating the
* Explore subagent model override configuration endpoint.
*/
export interface UpdateChatExploreModelOverrideRequest {
readonly model_config_id?: string;
}
// From codersdk/chats.go
/**
* UpdateChatModelConfigRequest updates a chat model config.
@@ -1,34 +1,83 @@
import type { FC } from "react";
import { useMutation, useQuery, useQueryClient } from "react-query";
import {
chatExploreModelOverride,
chatModelConfigs,
updateChatExploreModelOverride,
} from "#/api/queries/chats";
type QueryClient,
useMutation,
useQuery,
useQueryClient,
} from "react-query";
import { API } from "#/api/api";
import { chatModelConfigs } from "#/api/queries/chats";
import type * as TypesGen from "#/api/typesGenerated";
import { useAuthenticated } from "#/hooks/useAuthenticated";
import { RequirePermission } from "#/modules/permissions/RequirePermission";
import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView";
const generalOverrideContext: TypesGen.ChatAgentModelOverrideContext =
"general";
const exploreOverrideContext: TypesGen.ChatAgentModelOverrideContext =
"explore";
const chatAgentModelOverrideKey = (
context: TypesGen.ChatAgentModelOverrideContext,
) => ["chat-agent-model-override", context] as const;
const chatAgentModelOverrideQuery = (
context: TypesGen.ChatAgentModelOverrideContext,
) => ({
queryKey: chatAgentModelOverrideKey(context),
queryFn: () => API.experimental.getChatAgentModelOverride(context),
});
const updateChatAgentModelOverrideMutation = (
queryClient: QueryClient,
context: TypesGen.ChatAgentModelOverrideContext,
) => ({
mutationFn: (req: TypesGen.UpdateChatAgentModelOverrideRequest) =>
API.experimental.updateChatAgentModelOverride(context, req),
onSuccess: async () => {
await queryClient.invalidateQueries({
queryKey: chatAgentModelOverrideKey(context),
exact: true,
});
},
});
const AgentSettingsAgentsPage: FC = () => {
const { permissions } = useAuthenticated();
const queryClient = useQueryClient();
const canEditDeploymentConfig = permissions.editDeploymentConfig;
const generalModelOverrideQuery = useQuery({
...chatAgentModelOverrideQuery(generalOverrideContext),
enabled: canEditDeploymentConfig,
});
const exploreModelOverrideQuery = useQuery({
...chatExploreModelOverride(),
enabled: permissions.editDeploymentConfig,
...chatAgentModelOverrideQuery(exploreOverrideContext),
enabled: canEditDeploymentConfig,
});
const modelConfigsQuery = useQuery(chatModelConfigs());
const saveGeneralModelOverrideMutation = useMutation(
updateChatAgentModelOverrideMutation(queryClient, generalOverrideContext),
);
const saveExploreModelOverrideMutation = useMutation(
updateChatExploreModelOverride(queryClient),
updateChatAgentModelOverrideMutation(queryClient, exploreOverrideContext),
);
return (
<RequirePermission isFeatureVisible={permissions.editDeploymentConfig}>
<RequirePermission isFeatureVisible={canEditDeploymentConfig}>
<AgentSettingsAgentsPageView
generalModelOverrideData={generalModelOverrideQuery.data}
exploreModelOverrideData={exploreModelOverrideQuery.data}
modelConfigsData={modelConfigsQuery.data}
modelConfigsError={modelConfigsQuery.error}
isLoadingModelConfigs={modelConfigsQuery.isLoading}
onSaveGeneralModelOverride={saveGeneralModelOverrideMutation.mutate}
isSavingGeneralModelOverride={
saveGeneralModelOverrideMutation.isPending
}
isSaveGeneralModelOverrideError={
saveGeneralModelOverrideMutation.isError
}
onSaveExploreModelOverride={saveExploreModelOverrideMutation.mutate}
isSavingExploreModelOverride={
saveExploreModelOverrideMutation.isPending
@@ -6,208 +6,294 @@ import {
type AgentSettingsAgentsPageViewProps,
} from "./AgentSettingsAgentsPageView";
const baseArgs: AgentSettingsAgentsPageViewProps = {
exploreModelOverrideData: {
has_malformed_override: false,
},
modelConfigsData: [],
const OVERRIDE_MALFORMED_WARNING =
"The saved override is malformed and is being treated as unset. Click Save to clear it.";
const UNAVAILABLE_SAVED_MODEL_WARNING =
"The saved model is no longer enabled and will be ignored until you choose a new override.";
const buildModelConfig = (
overrides: Partial<TypesGen.ChatModelConfig>,
): TypesGen.ChatModelConfig => ({
id: "model-default",
provider: "openai",
model: "gpt-4.1-mini",
display_name: "GPT 4.1 Mini",
enabled: true,
is_default: false,
context_limit: 1_000_000,
compression_threshold: 70,
created_at: "2026-03-12T12:00:00.000Z",
updated_at: "2026-03-12T12:00:00.000Z",
...overrides,
});
const buildOverrideData = (
context: TypesGen.ChatAgentModelOverrideContext,
overrides: Partial<TypesGen.ChatAgentModelOverrideResponse> = {},
): TypesGen.ChatAgentModelOverrideResponse => ({
context,
model_config_id: "",
is_malformed: false,
...overrides,
});
const generalModelConfig = buildModelConfig({
id: "model-general-gpt-4.1-mini",
display_name: "GPT 4.1 Mini",
});
const claudeSonnetModelConfig = buildModelConfig({
id: "model-claude-sonnet-4",
provider: "anthropic",
model: "claude-sonnet-4",
display_name: "Claude Sonnet 4",
context_limit: 200_000,
});
const exploreFallbackModelConfig = buildModelConfig({
id: "model-explore-blank-display",
provider: "anthropic",
model: "claude-sonnet-4-20250514",
display_name: "",
context_limit: 200_000,
});
const generalDisabledModelConfig = buildModelConfig({
id: "model-general-disabled",
model: "gpt-4.1-legacy",
display_name: "GPT 4.1 Legacy",
enabled: false,
});
const exploreDisabledModelConfig = buildModelConfig({
id: "model-explore-disabled",
provider: "anthropic",
model: "claude-haiku-legacy",
display_name: "Claude Haiku Legacy",
enabled: false,
context_limit: 200_000,
});
const allModelConfigs: TypesGen.ChatModelConfig[] = [
generalModelConfig,
claudeSonnetModelConfig,
exploreFallbackModelConfig,
generalDisabledModelConfig,
exploreDisabledModelConfig,
];
const makeArgs = (
overrides: Partial<AgentSettingsAgentsPageViewProps> = {},
): AgentSettingsAgentsPageViewProps => ({
generalModelOverrideData: buildOverrideData("general"),
exploreModelOverrideData: buildOverrideData("explore"),
modelConfigsData: allModelConfigs,
modelConfigsError: undefined,
isLoadingModelConfigs: false,
onSaveGeneralModelOverride: fn(),
isSavingGeneralModelOverride: false,
isSaveGeneralModelOverrideError: false,
onSaveExploreModelOverride: fn(),
isSavingExploreModelOverride: false,
isSaveExploreModelOverrideError: false,
...overrides,
});
const getSection = async (
canvasElement: HTMLElement,
headingName: string,
): Promise<HTMLElement> => {
const canvas = within(canvasElement);
const heading = await canvas.findByRole("heading", { name: headingName });
const section = heading.closest("section");
if (!(section instanceof HTMLElement)) {
throw new Error(
`Expected ${headingName} heading to live inside a section.`,
);
}
return section;
};
const selectModelInSection = async (
section: HTMLElement,
canvasElement: HTMLElement,
currentSelectionName: string | RegExp,
optionName: string,
) => {
const trigger = within(section).getByRole("combobox", {
name: currentSelectionName,
});
await userEvent.click(trigger);
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(await body.findByRole("option", { name: optionName }));
};
const meta = {
title: "pages/AgentsPage/AgentSettingsAgentsPageView",
component: AgentSettingsAgentsPageView,
args: baseArgs,
args: makeArgs(),
} satisfies Meta<typeof AgentSettingsAgentsPageView>;
export default meta;
type Story = StoryObj<typeof AgentSettingsAgentsPageView>;
export const ExploreModelOverrideSetting: Story = {
args: {
exploreModelOverrideData: {
model_config_id: "model-explore-1",
has_malformed_override: false,
},
modelConfigsData: [
{
id: "model-explore-1",
provider: "openai",
model: "gpt-4.1-mini",
display_name: "GPT 4.1 Mini",
enabled: true,
is_default: false,
context_limit: 1_000_000,
compression_threshold: 70,
created_at: "2026-03-12T12:00:00.000Z",
updated_at: "2026-03-12T12:00:00.000Z",
},
{
id: "model-explore-2",
provider: "anthropic",
model: "claude-sonnet-4",
display_name: "Claude Sonnet 4",
enabled: true,
is_default: false,
context_limit: 200_000,
compression_threshold: 70,
created_at: "2026-03-12T12:00:00.000Z",
updated_at: "2026-03-12T12:00:00.000Z",
},
] as TypesGen.ChatModelConfig[],
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
await canvas.findByText("Agents");
await canvas.findByText("Explore subagent model");
const trigger = canvas.getByRole("combobox", {
name: /gpt 4.1 mini/i,
});
await userEvent.click(trigger);
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(
await body.findByRole("option", { name: "Claude Sonnet 4" }),
);
const form = trigger.closest("form");
if (!(form instanceof HTMLFormElement)) {
throw new Error("Expected Explore model selector to live inside a form.");
}
const saveButton = within(form).getByRole("button", { name: "Save" });
await waitFor(() => {
expect(saveButton).toBeEnabled();
});
await userEvent.click(saveButton);
await waitFor(() => {
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
{ model_config_id: "model-explore-2" },
expect.anything(),
);
});
},
};
export const ExploreModelOverrideAllowsExplicitClear: Story = {
args: {
exploreModelOverrideData: {
model_config_id: "model-explore-clear",
has_malformed_override: false,
},
modelConfigsData: [
{
id: "model-explore-clear",
provider: "openai",
model: "gpt-4.1-mini",
display_name: "GPT 4.1 Mini",
enabled: true,
is_default: false,
context_limit: 1_000_000,
compression_threshold: 70,
created_at: "2026-03-12T12:00:00.000Z",
updated_at: "2026-03-12T12:00:00.000Z",
},
] as TypesGen.ChatModelConfig[],
onSaveExploreModelOverride: fn(),
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
const clearButton = await canvas.findByRole("button", { name: "Clear" });
const form = clearButton.closest("form");
if (!(form instanceof HTMLFormElement)) {
throw new Error(
"Expected Explore model clear button to live inside a form.",
);
}
const saveButton = within(form).getByRole("button", { name: "Save" });
await userEvent.click(clearButton);
expect(args.onSaveExploreModelOverride).not.toHaveBeenCalled();
await waitFor(() => {
expect(saveButton).toBeEnabled();
});
await userEvent.click(saveButton);
await waitFor(() => {
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
{},
expect.anything(),
);
});
},
};
export const ExploreModelOverrideClearsMalformedSavedValue: Story = {
args: {
exploreModelOverrideData: {
has_malformed_override: true,
},
modelConfigsData: [],
onSaveExploreModelOverride: fn(),
},
play: async ({ canvasElement, args }) => {
const canvas = within(canvasElement);
await canvas.findByText(
"The saved override is malformed and is being treated as unset. Click Save to clear it.",
);
const clearButton = await canvas.findByRole("button", { name: "Clear" });
const form = clearButton.closest("form");
if (!(form instanceof HTMLFormElement)) {
throw new Error(
"Expected Explore model clear button to live inside a form.",
);
}
const saveButton = within(form).getByRole("button", { name: "Save" });
await waitFor(() => {
expect(saveButton).toBeEnabled();
});
await userEvent.click(clearButton);
expect(args.onSaveExploreModelOverride).not.toHaveBeenCalled();
await userEvent.click(saveButton);
await waitFor(() => {
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
{},
expect.anything(),
);
});
},
};
export const ExploreModelOverrideFallsBackToModelName: Story = {
args: {
exploreModelOverrideData: {
model_config_id: "model-explore-empty-name",
has_malformed_override: false,
},
modelConfigsData: [
{
id: "model-explore-empty-name",
provider: "anthropic",
model: "claude-sonnet-4-20250514",
display_name: "",
enabled: true,
is_default: false,
context_limit: 200_000,
compression_threshold: 70,
created_at: "2026-03-12T12:00:00.000Z",
updated_at: "2026-03-12T12:00:00.000Z",
},
] as TypesGen.ChatModelConfig[],
},
export const AllOverridesUnset: Story = {
args: makeArgs(),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const trigger = await canvas.findByRole("combobox", {
name: /claude-sonnet-4-20250514/i,
});
expect(trigger).toHaveTextContent("claude-sonnet-4-20250514");
await userEvent.click(trigger);
const body = within(canvasElement.ownerDocument.body);
expect(
await body.findByRole("option", {
name: "claude-sonnet-4-20250514",
}),
).toBeInTheDocument();
await canvas.findByText("Agents");
const headings = await canvas.findAllByRole("heading", { level: 3 });
expect(headings.map((heading) => heading.textContent?.trim())).toEqual([
"General model",
"Explore subagent model",
]);
for (const headingName of ["General model", "Explore subagent model"]) {
const section = await getSection(canvasElement, headingName);
expect(
within(section).getByRole("combobox", { name: "Use chat default" }),
).toBeInTheDocument();
expect(
within(section).getByRole("button", { name: "Save" }),
).toBeDisabled();
}
},
};
export const EachOverrideSetToEnabledModel: Story = {
args: makeArgs({
generalModelOverrideData: buildOverrideData("general", {
model_config_id: generalModelConfig.id,
}),
exploreModelOverrideData: buildOverrideData("explore", {
model_config_id: exploreFallbackModelConfig.id,
}),
}),
play: async ({ canvasElement, args }) => {
const generalSection = await getSection(canvasElement, "General model");
const exploreSection = await getSection(
canvasElement,
"Explore subagent model",
);
expect(
within(exploreSection).getByRole("combobox", {
name: /claude-sonnet-4-20250514/i,
}),
).toHaveTextContent("claude-sonnet-4-20250514");
await selectModelInSection(
generalSection,
canvasElement,
/gpt 4\.1 mini/i,
"Claude Sonnet 4",
);
const generalSaveButton = within(generalSection).getByRole("button", {
name: "Save",
});
await waitFor(() => {
expect(generalSaveButton).toBeEnabled();
});
await userEvent.click(generalSaveButton);
await waitFor(() => {
expect(args.onSaveGeneralModelOverride).toHaveBeenCalledWith(
{ model_config_id: claudeSonnetModelConfig.id },
expect.anything(),
);
});
const exploreClearButton = within(exploreSection).getByRole("button", {
name: "Clear",
});
await userEvent.click(exploreClearButton);
const exploreSaveButton = within(exploreSection).getByRole("button", {
name: "Save",
});
await waitFor(() => {
expect(exploreSaveButton).toBeEnabled();
});
await userEvent.click(exploreSaveButton);
await waitFor(() => {
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
{ model_config_id: "" },
expect.anything(),
);
});
},
};
export const MalformedOverridesRemainClearableAndSaveable: Story = {
args: makeArgs({
generalModelOverrideData: buildOverrideData("general", {
is_malformed: true,
}),
exploreModelOverrideData: buildOverrideData("explore", {
is_malformed: true,
}),
}),
play: async ({ canvasElement, args }) => {
const generalSection = await getSection(canvasElement, "General model");
const exploreSection = await getSection(
canvasElement,
"Explore subagent model",
);
for (const section of [generalSection, exploreSection]) {
await within(section).findByText(OVERRIDE_MALFORMED_WARNING);
}
const generalSaveButton = within(generalSection).getByRole("button", {
name: "Save",
});
await waitFor(() => {
expect(generalSaveButton).toBeEnabled();
});
await userEvent.click(generalSaveButton);
await waitFor(() => {
expect(args.onSaveGeneralModelOverride).toHaveBeenCalledWith(
{ model_config_id: "" },
expect.anything(),
);
});
const exploreSaveButton = within(exploreSection).getByRole("button", {
name: "Save",
});
await waitFor(() => {
expect(exploreSaveButton).toBeEnabled();
});
await userEvent.click(exploreSaveButton);
await waitFor(() => {
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
{ model_config_id: "" },
expect.anything(),
);
});
},
};
export const UnavailableSavedModels: Story = {
args: makeArgs({
generalModelOverrideData: buildOverrideData("general", {
model_config_id: generalDisabledModelConfig.id,
}),
exploreModelOverrideData: buildOverrideData("explore", {
model_config_id: exploreDisabledModelConfig.id,
}),
}),
play: async ({ canvasElement }) => {
const generalSection = await getSection(canvasElement, "General model");
const exploreSection = await getSection(
canvasElement,
"Explore subagent model",
);
for (const section of [generalSection, exploreSection]) {
await within(section).findByText(UNAVAILABLE_SAVED_MODEL_WARNING);
expect(
within(section).getByRole("combobox", { name: "Unavailable model" }),
).toBeInTheDocument();
}
},
};
@@ -1,24 +1,26 @@
import type { FC } from "react";
import type * as TypesGen from "#/api/typesGenerated";
import { ExploreModelOverrideSettings } from "./components/ExploreModelOverrideSettings";
import { SectionHeader } from "./components/SectionHeader";
import {
type MutationCallbacks,
SubagentModelOverrideSettings,
} from "./components/SubagentModelOverrideSettings";
interface MutationCallbacks {
onSuccess?: () => void;
onError?: () => void;
}
type SaveModelOverride = (
req: TypesGen.UpdateChatAgentModelOverrideRequest,
options?: MutationCallbacks,
) => void;
export interface AgentSettingsAgentsPageViewProps {
exploreModelOverrideData:
| TypesGen.ChatExploreModelOverrideResponse
| undefined;
generalModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
exploreModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
modelConfigsData: TypesGen.ChatModelConfig[] | undefined;
modelConfigsError: unknown;
isLoadingModelConfigs: boolean;
onSaveExploreModelOverride: (
req: TypesGen.UpdateChatExploreModelOverrideRequest,
options?: MutationCallbacks,
) => void;
onSaveGeneralModelOverride?: SaveModelOverride;
isSavingGeneralModelOverride?: boolean;
isSaveGeneralModelOverrideError?: boolean;
onSaveExploreModelOverride: SaveModelOverride;
isSavingExploreModelOverride: boolean;
isSaveExploreModelOverrideError: boolean;
}
@@ -26,37 +28,84 @@ export interface AgentSettingsAgentsPageViewProps {
export const AgentSettingsAgentsPageView: FC<
AgentSettingsAgentsPageViewProps
> = ({
generalModelOverrideData,
exploreModelOverrideData,
modelConfigsData,
modelConfigsError,
isLoadingModelConfigs,
onSaveGeneralModelOverride,
isSavingGeneralModelOverride = false,
isSaveGeneralModelOverrideError = false,
onSaveExploreModelOverride,
isSavingExploreModelOverride,
isSaveExploreModelOverrideError,
}) => {
const enabledModelConfigs = (modelConfigsData ?? []).filter(
(modelConfig) => modelConfig.enabled,
);
const showGeneralModelSection =
onSaveGeneralModelOverride !== undefined ||
generalModelOverrideData !== undefined ||
isSavingGeneralModelOverride ||
isSaveGeneralModelOverrideError;
return (
<div className="flex flex-col gap-8">
<SectionHeader
label="Agents"
description="Configure defaults for delegated agents and other agent-specific capabilities."
/>
<div className="flex flex-col gap-3">
{showGeneralModelSection && onSaveGeneralModelOverride && (
<section aria-label="General model" className="flex flex-col gap-3">
<SectionHeader
label="General model"
description="Deployment-wide model override for delegated subagents with write capabilities, such as editing files or running commands in the workspace."
level="section"
/>
<SubagentModelOverrideSettings
title="General model"
description="Deployment-wide model override for delegated subagents with write capabilities, such as editing files or running commands in the workspace."
modelOverrideData={generalModelOverrideData}
enabledModelConfigs={enabledModelConfigs}
modelConfigsError={modelConfigsError}
isLoading={isLoadingModelConfigs}
onSaveModelOverride={onSaveGeneralModelOverride}
isSaving={isSavingGeneralModelOverride}
isSaveError={isSaveGeneralModelOverrideError}
saveErrorMessage="Failed to save general model override."
showHeader={false}
/>
</section>
)}
<section
aria-label="Explore subagent model"
className="flex flex-col gap-3"
>
<SectionHeader
label="Explore subagent model"
description="Optional deployment-wide model override for read-only Explore subagents."
description="Deployment-wide model override for read-only Explore subagents."
level="section"
/>
<ExploreModelOverrideSettings
exploreModelOverrideData={exploreModelOverrideData}
modelConfigs={modelConfigsData ?? []}
<SubagentModelOverrideSettings
title="Explore subagent model"
description={
<>
Deployment-wide model override for read-only Explore subagents
launched through the <code>spawn_agent</code> tool with a
<code>type=explore</code> argument.
</>
}
modelOverrideData={exploreModelOverrideData}
enabledModelConfigs={enabledModelConfigs}
modelConfigsError={modelConfigsError}
isLoadingModelConfigs={isLoadingModelConfigs}
onSaveExploreModelOverride={onSaveExploreModelOverride}
isSavingExploreModelOverride={isSavingExploreModelOverride}
isSaveExploreModelOverrideError={isSaveExploreModelOverrideError}
isLoading={isLoadingModelConfigs}
onSaveModelOverride={onSaveExploreModelOverride}
isSaving={isSavingExploreModelOverride}
isSaveError={isSaveExploreModelOverrideError}
saveErrorMessage="Failed to save Explore model override."
showHeader={false}
/>
</div>
</section>
</div>
);
};
@@ -156,7 +156,11 @@ const fixedNow = dayjs("2026-03-12T12:00:00");
const AgentsRouteElement = () => (
<AgentSettingsAgentsPageView
exploreModelOverrideData={{ has_malformed_override: false }}
exploreModelOverrideData={{
context: "explore",
model_config_id: "",
is_malformed: false,
}}
modelConfigsData={[]}
modelConfigsError={undefined}
isLoadingModelConfigs={false}
@@ -609,10 +613,23 @@ export const WithErrorReasons: Story = {
const openSettingsView = async (canvasElement: HTMLElement) => {
const canvas = within(canvasElement);
const link = await waitFor(() =>
canvas.getByRole("link", { name: "Settings" }),
const settingsLink = canvas.queryByRole("link", { name: "Settings" });
if (settingsLink) {
await userEvent.click(settingsLink);
return;
}
const mobileMoreOptionsButton = canvas
.getAllByRole("button", { name: "More options" })
.find((button) => button.getAttribute("aria-haspopup") === "menu");
if (!mobileMoreOptionsButton) {
throw new Error("Expected a mobile More options menu button.");
}
await userEvent.click(mobileMoreOptionsButton);
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(
await body.findByRole("menuitem", { name: "Settings" }),
);
await userEvent.click(link);
};
export const OpensAnalyticsForAdmins: Story = {
@@ -1,176 +0,0 @@
import { useFormik } from "formik";
import type { FC } from "react";
import type * as TypesGen from "#/api/typesGenerated";
import { Alert, AlertDescription } from "#/components/Alert/Alert";
import { Button } from "#/components/Button/Button";
import type { ModelSelectorOption } from "./ChatElements/ModelSelector";
import { ModelSelector } from "./ChatElements/ModelSelector";
interface MutationCallbacks {
onSuccess?: () => void;
onError?: () => void;
}
interface ExploreModelOverrideSettingsProps {
exploreModelOverrideData:
| TypesGen.ChatExploreModelOverrideResponse
| undefined;
modelConfigs: readonly TypesGen.ChatModelConfig[];
modelConfigsError: unknown;
isLoadingModelConfigs: boolean;
onSaveExploreModelOverride: (
req: TypesGen.UpdateChatExploreModelOverrideRequest,
options?: MutationCallbacks,
) => void;
isSavingExploreModelOverride: boolean;
isSaveExploreModelOverrideError: boolean;
showHeader?: boolean;
}
const toModelSelectorOption = (
modelConfig: TypesGen.ChatModelConfig,
): ModelSelectorOption => ({
id: modelConfig.id,
provider: modelConfig.provider,
model: modelConfig.model,
displayName: modelConfig.display_name.trim() || modelConfig.model,
contextLimit: modelConfig.context_limit,
});
export const ExploreModelOverrideSettings: FC<
ExploreModelOverrideSettingsProps
> = ({
exploreModelOverrideData,
modelConfigs,
modelConfigsError,
isLoadingModelConfigs,
onSaveExploreModelOverride,
isSavingExploreModelOverride,
isSaveExploreModelOverrideError,
showHeader = true,
}) => {
const hasLoadedExploreModelOverride = exploreModelOverrideData !== undefined;
const enabledModelOptions = modelConfigs
.filter((modelConfig) => modelConfig.enabled)
.map(toModelSelectorOption);
const form = useFormik({
enableReinitialize: true,
initialValues: {
model_config_id: exploreModelOverrideData?.model_config_id ?? "",
},
onSubmit: (values, { resetForm }) => {
onSaveExploreModelOverride(
{
model_config_id: values.model_config_id || undefined,
},
{
onSuccess: () => {
resetForm({ values });
},
},
);
},
});
const isUnavailableSavedModel =
form.values.model_config_id !== "" &&
!enabledModelOptions.some(
(option) => option.id === form.values.model_config_id,
);
const hasMalformedOverride =
exploreModelOverrideData?.has_malformed_override ?? false;
const isExploreModelOverrideDisabled =
isSavingExploreModelOverride ||
isLoadingModelConfigs ||
!hasLoadedExploreModelOverride;
const canSaveExploreModelOverride =
hasLoadedExploreModelOverride && (form.dirty || hasMalformedOverride);
return (
<form className="space-y-2" onSubmit={form.handleSubmit}>
{showHeader && (
<>
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Explore subagent model
</h3>
</div>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Optional deployment-wide model override for read-only Explore
subagents spawned with <code>spawn_agent</code> using
<code>type=explore</code>.
</p>
</>
)}
<div className="rounded-lg border border-border bg-surface-primary px-3 py-2">
<ModelSelector
options={enabledModelOptions}
value={form.values.model_config_id}
onValueChange={(value) =>
form.setFieldValue("model_config_id", value)
}
disabled={isExploreModelOverrideDisabled}
placeholder={
isUnavailableSavedModel ? "Unavailable model" : "Use chat default"
}
emptyMessage={
isLoadingModelConfigs
? "Loading models..."
: "No enabled models found."
}
className="h-10 w-full justify-between rounded-md border border-border border-solid bg-transparent px-3 text-sm shadow-sm"
contentClassName="min-w-[18rem]"
/>
</div>
{isUnavailableSavedModel && (
<Alert severity="warning">
<AlertDescription>
The saved model is no longer enabled and will be ignored until you
choose a new override.
</AlertDescription>
</Alert>
)}
{hasMalformedOverride && (
<Alert severity="warning">
<AlertDescription>
The saved override is malformed and is being treated as unset. Click
Save to clear it.
</AlertDescription>
</Alert>
)}
{Boolean(modelConfigsError) && (
<p className="m-0 text-xs text-content-destructive">
Failed to load model configs.
</p>
)}
<div className="flex justify-end gap-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={() => {
form.setFieldValue("model_config_id", "");
}}
disabled={isExploreModelOverrideDisabled}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={
isExploreModelOverrideDisabled || !canSaveExploreModelOverride
}
>
Save
</Button>
</div>
{isSaveExploreModelOverrideError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save Explore model override.
</p>
)}
</form>
);
};
@@ -0,0 +1,174 @@
import { useFormik } from "formik";
import type { FC, ReactNode } from "react";
import type * as TypesGen from "#/api/typesGenerated";
import { Alert, AlertDescription } from "#/components/Alert/Alert";
import { Button } from "#/components/Button/Button";
import type { ModelSelectorOption } from "./ChatElements/ModelSelector";
import { ModelSelector } from "./ChatElements/ModelSelector";
export interface MutationCallbacks {
onSuccess?: () => void;
onError?: () => void;
}
interface ModelOverrideData {
readonly model_config_id: string;
readonly is_malformed: boolean;
}
interface UpdateModelOverrideRequest {
readonly model_config_id: string;
}
interface SubagentModelOverrideSettingsProps {
title: string;
description: ReactNode;
modelOverrideData: ModelOverrideData | undefined;
enabledModelConfigs: readonly TypesGen.ChatModelConfig[];
modelConfigsError: unknown;
isLoading: boolean;
onSaveModelOverride: (
req: UpdateModelOverrideRequest,
options?: MutationCallbacks,
) => void;
isSaving: boolean;
isSaveError: boolean;
saveErrorMessage: string;
showHeader?: boolean;
disabled?: boolean;
}
const toModelSelectorOption = (
modelConfig: TypesGen.ChatModelConfig,
): ModelSelectorOption => ({
id: modelConfig.id,
provider: modelConfig.provider,
model: modelConfig.model,
displayName: modelConfig.display_name.trim() || modelConfig.model,
contextLimit: modelConfig.context_limit,
});
export const SubagentModelOverrideSettings: FC<
SubagentModelOverrideSettingsProps
> = ({
title,
description,
modelOverrideData,
enabledModelConfigs,
modelConfigsError,
isLoading,
onSaveModelOverride,
isSaving,
isSaveError,
saveErrorMessage,
showHeader = true,
disabled = false,
}) => {
const hasLoadedModelOverride = modelOverrideData !== undefined;
const enabledModelOptions = enabledModelConfigs.map(toModelSelectorOption);
const form = useFormik({
enableReinitialize: true,
initialValues: {
model_config_id: modelOverrideData?.model_config_id ?? "",
},
onSubmit: (values, { resetForm }) => {
onSaveModelOverride(
{
model_config_id: values.model_config_id,
},
{
onSuccess: () => {
resetForm({ values });
},
},
);
},
});
const isUnavailableSavedModel =
form.values.model_config_id !== "" &&
!enabledModelOptions.some(
(option) => option.id === form.values.model_config_id,
);
const isMalformedOverride = modelOverrideData?.is_malformed ?? false;
const isModelOverrideDisabled =
disabled || isSaving || isLoading || !hasLoadedModelOverride;
const canSaveModelOverride =
hasLoadedModelOverride && (form.dirty || isMalformedOverride);
return (
<form aria-label={title} className="space-y-2" onSubmit={form.handleSubmit}>
{showHeader && (
<>
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
{title}
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
{description}
</p>
</>
)}
<ModelSelector
options={enabledModelOptions}
value={form.values.model_config_id}
onValueChange={(value) => form.setFieldValue("model_config_id", value)}
disabled={isModelOverrideDisabled}
placeholder={
isUnavailableSavedModel ? "Unavailable model" : "Use chat default"
}
emptyMessage={
isLoading ? "Loading models..." : "No enabled models found."
}
className="h-10 w-full justify-between rounded-md border border-border border-solid bg-transparent px-3 text-sm shadow-sm"
contentClassName="min-w-[18rem]"
/>
{isUnavailableSavedModel && (
<Alert severity="warning">
<AlertDescription>
The saved model is no longer enabled and will be ignored until you
choose a new override.
</AlertDescription>
</Alert>
)}
{isMalformedOverride && (
<Alert severity="warning">
<AlertDescription>
The saved override is malformed and is being treated as unset. Click
Save to clear it.
</AlertDescription>
</Alert>
)}
{Boolean(modelConfigsError) && (
<p className="m-0 text-xs text-content-destructive">
Failed to load model configs.
</p>
)}
<div className="flex justify-end gap-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={() => {
form.setFieldValue("model_config_id", "");
}}
disabled={isModelOverrideDisabled}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isModelOverrideDisabled || !canSaveModelOverride}
>
Save
</Button>
</div>
{isSaveError && (
<p className="m-0 text-xs text-content-destructive">
{saveErrorMessage}
</p>
)}
</form>
);
};