mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add general subagent model override (#24610)
Adds a deployment-wide admin override for general delegated subagents.
## What changed
- store the general override in `site_configs` and expose it through the
shared `agent-model-override/{context}` API
- apply the general override when spawning delegated general subagents,
while preserving the existing Explore override behavior
- reuse a shared Agents settings form for the general and Explore
override sections
## Validation
- `make gen`
- `go test ./coderd -run 'TestChatModelOverrides'`
- `go test ./coderd/x/chatd -run
'TestSpawnAgent_(GeneralUsesConfiguredModelOverride|GeneralOverrideLogsAndFallsBackWhenCredentialsUnavailable|GeneralOverrideLogsAndFallsBackWhenProviderDisabled)'`
- `pnpm -C site lint:types`
- `pnpm -C site test:storybook --
AgentSettingsAgentsPageView.stories.tsx`
- `make lint`
- `make pre-commit`
> Mux is acting on Mike's behalf.
This commit is contained in:
+2
-2
@@ -1184,8 +1184,8 @@ func New(options *Options) *API {
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions)
|
||||
r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions)
|
||||
r.Get("/explore-model-override", api.getChatExploreModelOverride)
|
||||
r.Put("/explore-model-override", api.putChatExploreModelOverride)
|
||||
r.Get("/agent-model-override/{context}", api.getChatAgentModelOverride)
|
||||
r.Put("/agent-model-override/{context}", api.putChatAgentModelOverride)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/debug-logging", api.getChatDebugLogging)
|
||||
|
||||
@@ -2737,6 +2737,13 @@ func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]dat
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.db.GetChatGeneralModelOverride(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
|
||||
// The include-default-system-prompt flag is a deployment-wide setting read
|
||||
// during chat creation by every authenticated user, so no RBAC policy
|
||||
@@ -7390,6 +7397,13 @@ func (q *querier) UpsertChatExploreModelOverride(ctx context.Context, value stri
|
||||
return q.db.UpsertChatExploreModelOverride(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatGeneralModelOverride(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -890,6 +890,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatGeneralModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatExploreModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
@@ -1205,6 +1209,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatGeneralModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatExploreModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
|
||||
@@ -1256,6 +1256,14 @@ func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUI
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatGeneralModelOverride(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatGeneralModelOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatGeneralModelOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatIncludeDefaultSystemPrompt(ctx)
|
||||
@@ -5288,6 +5296,14 @@ func (m queryMetricsStore) UpsertChatExploreModelOverride(ctx context.Context, v
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatGeneralModelOverride(ctx, value)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatGeneralModelOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatGeneralModelOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
|
||||
@@ -2311,6 +2311,21 @@ func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatGeneralModelOverride mocks base method.
|
||||
func (m *MockStore) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatGeneralModelOverride", ctx)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatGeneralModelOverride indicates an expected call of GetChatGeneralModelOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatGeneralModelOverride(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatGeneralModelOverride", reflect.TypeOf((*MockStore)(nil).GetChatGeneralModelOverride), ctx)
|
||||
}
|
||||
|
||||
// GetChatIncludeDefaultSystemPrompt mocks base method.
|
||||
func (m *MockStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9933,6 +9948,20 @@ func (mr *MockStoreMockRecorder) UpsertChatExploreModelOverride(ctx, value any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatExploreModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatExploreModelOverride), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatGeneralModelOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatGeneralModelOverride", ctx, value)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatGeneralModelOverride indicates an expected call of UpsertChatGeneralModelOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatGeneralModelOverride(ctx, value any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatGeneralModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatGeneralModelOverride), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatIncludeDefaultSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -308,6 +308,7 @@ type sqlcQuerier interface {
|
||||
// loading file content.
|
||||
GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error)
|
||||
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error)
|
||||
GetChatGeneralModelOverride(ctx context.Context) (string, error)
|
||||
// GetChatIncludeDefaultSystemPrompt preserves the legacy default
|
||||
// for deployments created before the explicit include-default toggle.
|
||||
// When the toggle is unset, a non-empty custom prompt implies false;
|
||||
@@ -1174,6 +1175,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatExploreModelOverride(ctx context.Context, value string) error
|
||||
UpsertChatGeneralModelOverride(ctx context.Context, value string) error
|
||||
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
|
||||
UpsertChatPlanModeInstructions(ctx context.Context, value string) error
|
||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||
|
||||
@@ -20406,6 +20406,18 @@ func (q *sqlQuerier) GetChatExploreModelOverride(ctx context.Context) (string, e
|
||||
return model_config_id, err
|
||||
}
|
||||
|
||||
const getChatGeneralModelOverride = `-- name: GetChatGeneralModelOverride :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_general_model_override'), '') :: text AS model_config_id
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatGeneralModelOverride(ctx context.Context) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatGeneralModelOverride)
|
||||
var model_config_id string
|
||||
err := row.Scan(&model_config_id)
|
||||
return model_config_id, err
|
||||
}
|
||||
|
||||
const getChatIncludeDefaultSystemPrompt = `-- name: GetChatIncludeDefaultSystemPrompt :one
|
||||
SELECT
|
||||
COALESCE(
|
||||
@@ -20773,6 +20785,16 @@ func (q *sqlQuerier) UpsertChatExploreModelOverride(ctx context.Context, value s
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatGeneralModelOverride = `-- name: UpsertChatGeneralModelOverride :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) UpsertChatGeneralModelOverride(ctx context.Context, value string) error {
|
||||
_, err := q.db.ExecContext(ctx, upsertChatGeneralModelOverride, value)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatIncludeDefaultSystemPrompt = `-- name: UpsertChatIncludeDefaultSystemPrompt :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
|
||||
@@ -175,6 +175,14 @@ SELECT
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_explore_model_override', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_explore_model_override';
|
||||
|
||||
-- name: GetChatGeneralModelOverride :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_general_model_override'), '') :: text AS model_config_id;
|
||||
|
||||
-- name: UpsertChatGeneralModelOverride :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override';
|
||||
|
||||
-- name: GetChatDesktopEnabled :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop;
|
||||
|
||||
+145
-24
@@ -449,7 +449,7 @@ func validateChatPlanMode(mode codersdk.ChatPlanMode) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func parseChatExploreModelOverride(raw string) (*uuid.UUID, error) {
|
||||
func parseChatModelOverride(raw string) (*uuid.UUID, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
//nolint:nilnil // Empty site-config value means the override is unset.
|
||||
@@ -457,19 +457,29 @@ func parseChatExploreModelOverride(raw string) (*uuid.UUID, error) {
|
||||
}
|
||||
modelConfigID, err := uuid.Parse(trimmed)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse explore model override: %w", err)
|
||||
return nil, xerrors.Errorf("parse chat model override: %w", err)
|
||||
}
|
||||
return &modelConfigID, nil
|
||||
}
|
||||
|
||||
func formatChatExploreModelOverride(id *uuid.UUID) string {
|
||||
func formatChatModelOverride(id *uuid.UUID) string {
|
||||
if id == nil {
|
||||
return ""
|
||||
}
|
||||
return id.String()
|
||||
}
|
||||
|
||||
func validateChatExploreModelOverrideID(
|
||||
func lookupEnabledChatModelConfigByID(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
id uuid.UUID,
|
||||
) (database.ChatModelConfig, error) {
|
||||
//nolint:gocritic // Validation lookup uses AsChatd to check model
|
||||
// availability independently of the caller's read permissions.
|
||||
return db.GetEnabledChatModelConfigByID(dbauthz.AsChatd(ctx), id)
|
||||
}
|
||||
|
||||
func validateChatModelOverrideID(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
id *uuid.UUID,
|
||||
@@ -482,9 +492,7 @@ func validateChatExploreModelOverrideID(
|
||||
Message: "Invalid model_config_id.",
|
||||
}
|
||||
}
|
||||
//nolint:gocritic // Validation lookup uses system context to check model
|
||||
// availability independently of the caller's read permissions.
|
||||
_, err := db.GetEnabledChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), *id)
|
||||
_, err := lookupEnabledChatModelConfigByID(ctx, db, *id)
|
||||
if err == nil {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -499,18 +507,23 @@ func validateChatExploreModelOverrideID(
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) getChatExploreModelOverrideConfig(
|
||||
func (api *API) getChatModelOverrideConfig(
|
||||
ctx context.Context,
|
||||
settingName string,
|
||||
getter func(context.Context) (string, error),
|
||||
) (*uuid.UUID, bool, error) {
|
||||
raw, err := api.Database.GetChatExploreModelOverride(ctx)
|
||||
raw, err := getter(ctx)
|
||||
if err != nil {
|
||||
return nil, false, xerrors.Errorf("get explore model override: %w", err)
|
||||
return nil, false, xerrors.Errorf("get %s model override: %w", settingName, err)
|
||||
}
|
||||
id, err := parseChatExploreModelOverride(raw)
|
||||
id, err := parseChatModelOverride(raw)
|
||||
if err != nil {
|
||||
// Degrade malformed values to unset so the admin settings page
|
||||
// remains accessible and the bad value can be cleared.
|
||||
api.Logger.Warn(ctx, "malformed explore model override in site config, treating as unset",
|
||||
api.Logger.Warn(
|
||||
ctx,
|
||||
"malformed model override in site config, treating as unset",
|
||||
slog.F("setting", settingName),
|
||||
slog.F("raw_value", raw),
|
||||
slog.Error(err),
|
||||
)
|
||||
@@ -519,6 +532,64 @@ func (api *API) getChatExploreModelOverrideConfig(
|
||||
return id, false, nil
|
||||
}
|
||||
|
||||
func parseChatAgentModelOverrideContext(raw string) (codersdk.ChatAgentModelOverrideContext, error) {
|
||||
overrideContext := codersdk.ChatAgentModelOverrideContext(raw)
|
||||
if overrideContext.Valid() {
|
||||
return overrideContext, nil
|
||||
}
|
||||
return "", xerrors.Errorf("unknown chat agent model override context %q", raw)
|
||||
}
|
||||
|
||||
type chatAgentModelOverrideSiteConfig struct {
|
||||
getter func(context.Context) (string, error)
|
||||
upsert func(context.Context, string) error
|
||||
}
|
||||
|
||||
func (api *API) chatAgentModelOverrideSiteConfig(
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (chatAgentModelOverrideSiteConfig, error) {
|
||||
switch overrideContext {
|
||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
||||
return chatAgentModelOverrideSiteConfig{
|
||||
getter: api.Database.GetChatGeneralModelOverride,
|
||||
upsert: api.Database.UpsertChatGeneralModelOverride,
|
||||
}, nil
|
||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
||||
return chatAgentModelOverrideSiteConfig{
|
||||
getter: api.Database.GetChatExploreModelOverride,
|
||||
upsert: api.Database.UpsertChatExploreModelOverride,
|
||||
}, nil
|
||||
default:
|
||||
return chatAgentModelOverrideSiteConfig{}, xerrors.Errorf(
|
||||
"unknown chat agent model override context %q",
|
||||
overrideContext,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) getChatAgentModelOverrideConfig(
|
||||
ctx context.Context,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (*uuid.UUID, bool, error) {
|
||||
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return api.getChatModelOverrideConfig(ctx, string(overrideContext), siteConfig.getter)
|
||||
}
|
||||
|
||||
func (api *API) upsertChatAgentModelOverrideConfig(
|
||||
ctx context.Context,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
modelConfigID *uuid.UUID,
|
||||
) error {
|
||||
siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@@ -3827,53 +3898,103 @@ func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Requ
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func readChatAgentModelOverrideContext(
|
||||
rw http.ResponseWriter,
|
||||
r *http.Request,
|
||||
) (codersdk.ChatAgentModelOverrideContext, bool) {
|
||||
ctx := r.Context()
|
||||
rawContext := chi.URLParam(r, "context")
|
||||
overrideContext, err := parseChatAgentModelOverrideContext(rawContext)
|
||||
if err == nil {
|
||||
return overrideContext, true
|
||||
}
|
||||
validContextValues := make(
|
||||
[]string,
|
||||
0,
|
||||
len(codersdk.AllChatAgentModelOverrideContexts()),
|
||||
)
|
||||
for _, overrideContext := range codersdk.AllChatAgentModelOverrideContexts() {
|
||||
validContextValues = append(validContextValues, string(overrideContext))
|
||||
}
|
||||
validContexts := strings.Join(validContextValues, ", ")
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat agent model override context.",
|
||||
Detail: fmt.Sprintf(
|
||||
"Expected one of %s. Got %q.",
|
||||
validContexts,
|
||||
rawContext,
|
||||
),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatExploreModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
modelConfigID, hasMalformedOverride, err := api.getChatExploreModelOverrideConfig(ctx)
|
||||
modelConfigID, isMalformed, err := api.getChatAgentModelOverrideConfig(ctx, overrideContext)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching Explore model override.",
|
||||
Message: fmt.Sprintf("Internal error fetching %s model override.", overrideContext),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatExploreModelOverrideResponse{
|
||||
ModelConfigID: modelConfigID,
|
||||
HasMalformedOverride: hasMalformedOverride,
|
||||
})
|
||||
resp := codersdk.ChatAgentModelOverrideResponse{
|
||||
Context: overrideContext,
|
||||
ModelConfigID: formatChatModelOverride(modelConfigID),
|
||||
IsMalformed: isMalformed,
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putChatExploreModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
overrideContext, ok := readChatAgentModelOverrideContext(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateChatExploreModelOverrideRequest
|
||||
var req codersdk.UpdateChatAgentModelOverrideRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
status, resp := validateChatExploreModelOverrideID(ctx, api.Database, req.ModelConfigID)
|
||||
modelConfigID, err := parseChatModelOverride(req.ModelConfigID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid model_config_id.",
|
||||
Detail: fmt.Sprintf("Value %q is not a valid UUID.", req.ModelConfigID),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
status, resp := validateChatModelOverrideID(ctx, api.Database, modelConfigID)
|
||||
if resp != nil {
|
||||
httpapi.Write(ctx, rw, status, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.Database.UpsertChatExploreModelOverride(ctx, formatChatExploreModelOverride(req.ModelConfigID)); err != nil {
|
||||
if err := api.upsertChatAgentModelOverrideConfig(ctx, overrideContext, modelConfigID); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating Explore model override.",
|
||||
Message: fmt.Sprintf("Internal error updating %s model override.", overrideContext),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
|
||||
+310
-132
@@ -1134,43 +1134,63 @@ func TestListChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Insert the chat that will later be pinned directly
|
||||
// into the database with a completed status so we
|
||||
// avoid the background chatd processor entirely.
|
||||
pinnedDBChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "pinned-chat",
|
||||
Status: database.ChatStatusCompleted,
|
||||
ClientType: database.ChatClientTypeUi,
|
||||
// Create the chat that will later be pinned. It gets the
|
||||
// earliest updated_at because it is inserted first.
|
||||
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "pinned-chat",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fill page 1 with newer chats so the pinned chat
|
||||
// would normally be pushed off the first page
|
||||
// (default limit 50). Insert directly into the
|
||||
// database to avoid spawning 51 background chat
|
||||
// processors, which causes timeouts under -race.
|
||||
// Fill page 1 with newer chats so the pinned chat would
|
||||
// normally be pushed off the first page (default limit 50).
|
||||
const fillerCount = 51
|
||||
fillerChats := make([]codersdk.Chat, 0, fillerCount)
|
||||
for i := range fillerCount {
|
||||
_, insertErr := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: fmt.Sprintf("filler-%d", i),
|
||||
Status: database.ChatStatusCompleted,
|
||||
ClientType: database.ChatClientTypeUi,
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("filler-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, insertErr)
|
||||
require.NoError(t, createErr)
|
||||
fillerChats = append(fillerChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach a terminal status so
|
||||
// updated_at is stable before paginating. A single
|
||||
// polling loop checks every chat per tick to avoid
|
||||
// O(N) separate Eventually loops.
|
||||
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
|
||||
pending := make(map[uuid.UUID]struct{}, len(allCreated))
|
||||
for _, c := range allCreated {
|
||||
pending[c.ID] = struct{}{}
|
||||
}
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
|
||||
})
|
||||
if listErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, ch := range all {
|
||||
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
|
||||
delete(pending, ch.ID)
|
||||
}
|
||||
}
|
||||
return len(pending) == 0
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the earliest chat.
|
||||
err = client.UpdateChat(ctx, pinnedDBChat.ID, codersdk.UpdateChatRequest{
|
||||
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1186,11 +1206,11 @@ func TestListChats(t *testing.T) {
|
||||
for _, c := range page1 {
|
||||
page1IDs[c.ID] = struct{}{}
|
||||
}
|
||||
_, found := page1IDs[pinnedDBChat.ID]
|
||||
_, found := page1IDs[pinnedChat.ID]
|
||||
require.True(t, found, "pinned chat should appear on page 1")
|
||||
|
||||
// The pinned chat should be the first item in the list.
|
||||
require.Equal(t, pinnedDBChat.ID, page1[0].ID, "pinned chat should be first")
|
||||
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
|
||||
})
|
||||
|
||||
// Test cursor pagination with a mix of pinned and unpinned chats.
|
||||
@@ -9396,6 +9416,44 @@ func createChatModelConfig(t *testing.T, client *codersdk.ExperimentalClient) co
|
||||
return modelConfig
|
||||
}
|
||||
|
||||
func createAdditionalChatModelConfig(
|
||||
t *testing.T,
|
||||
client *codersdk.ExperimentalClient,
|
||||
provider string,
|
||||
model string,
|
||||
) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
contextLimit := int64(4096)
|
||||
isDefault := false
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return modelConfig
|
||||
}
|
||||
|
||||
func createDisabledChatModelConfig(
|
||||
t *testing.T,
|
||||
client *codersdk.ExperimentalClient,
|
||||
provider string,
|
||||
model string,
|
||||
) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
modelConfig := createAdditionalChatModelConfig(t, client, provider, model)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
Enabled: ptr.Ref(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return updated
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -9897,129 +9955,249 @@ func TestChatPlanModeInstructions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
func TestChatExploreModelOverride(t *testing.T) {
|
||||
//nolint:tparallel,paralleltest // Setting subtests share per-setting coderdtest instances.
|
||||
func TestChatModelOverrides(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
adminClient, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
defaultModel := createChatModelConfig(t, adminClient)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
createAdditionalModel := func(t *testing.T, model string, enabled bool) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
contextLimit := int64(4096)
|
||||
isDefault := false
|
||||
modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: defaultModel.Provider,
|
||||
Model: model,
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if enabled {
|
||||
return modelConfig
|
||||
}
|
||||
updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
Enabled: ptr.Ref(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return updated
|
||||
type overrideResponse struct {
|
||||
context codersdk.ChatAgentModelOverrideContext
|
||||
modelConfigID string
|
||||
isMalformed bool
|
||||
}
|
||||
|
||||
t.Run("DefaultGETReturnsEmpty", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
type settingTest struct {
|
||||
name string
|
||||
context codersdk.ChatAgentModelOverrideContext
|
||||
dbGet func(context.Context, database.Store) (string, error)
|
||||
dbUpsert func(context.Context, database.Store, string) error
|
||||
}
|
||||
|
||||
resp, err := adminClient.GetChatExploreModelOverride(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.ModelConfigID)
|
||||
require.False(t, resp.HasMalformedOverride)
|
||||
})
|
||||
settingPath := func(overrideContext codersdk.ChatAgentModelOverrideContext) string {
|
||||
return "/api/experimental/chats/config/agent-model-override/" + string(overrideContext)
|
||||
}
|
||||
|
||||
t.Run("AdminCanSetAndClear", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
overrideModel := createAdditionalModel(t, "gpt-4.1-mini", true)
|
||||
getOverride := func(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (overrideResponse, error) {
|
||||
resp, err := client.GetChatAgentModelOverride(ctx, overrideContext)
|
||||
if err != nil {
|
||||
return overrideResponse{}, err
|
||||
}
|
||||
return overrideResponse{
|
||||
context: resp.Context,
|
||||
modelConfigID: resp.ModelConfigID,
|
||||
isMalformed: resp.IsMalformed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
|
||||
ModelConfigID: &overrideModel.ID,
|
||||
putOverride := func(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
modelConfigID string,
|
||||
) error {
|
||||
return client.UpdateChatAgentModelOverride(
|
||||
ctx,
|
||||
overrideContext,
|
||||
codersdk.UpdateChatAgentModelOverrideRequest{ModelConfigID: modelConfigID},
|
||||
)
|
||||
}
|
||||
|
||||
settings := []settingTest{
|
||||
{
|
||||
name: "General",
|
||||
context: codersdk.ChatAgentModelOverrideContextGeneral,
|
||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||
return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||
},
|
||||
dbUpsert: func(ctx context.Context, db database.Store, value string) error {
|
||||
return db.UpsertChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Explore",
|
||||
context: codersdk.ChatAgentModelOverrideContextExplore,
|
||||
dbGet: func(ctx context.Context, db database.Store) (string, error) {
|
||||
return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx))
|
||||
},
|
||||
dbUpsert: func(ctx context.Context, db database.Store, value string) error {
|
||||
return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, setting := range settings {
|
||||
t.Run(setting.name, func(t *testing.T) {
|
||||
adminClient, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
defaultModel := createChatModelConfig(t, adminClient)
|
||||
openAIModel := createAdditionalChatModelConfig(
|
||||
t,
|
||||
adminClient,
|
||||
defaultModel.Provider,
|
||||
"gpt-4.1-mini-"+string(setting.context),
|
||||
)
|
||||
disabledModel := createDisabledChatModelConfig(
|
||||
t,
|
||||
adminClient,
|
||||
defaultModel.Provider,
|
||||
"gpt-4.1-disabled-"+string(setting.context),
|
||||
)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
t.Run("DefaultGETReturnsEmpty", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
resp, err := getOverride(ctx, adminClient, setting.context)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, setting.context, resp.context)
|
||||
require.Empty(t, resp.modelConfigID)
|
||||
require.False(t, resp.isMalformed)
|
||||
|
||||
raw, err := setting.dbGet(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, raw, "expected empty stored override for %s", settingPath(setting.context))
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetAndClear", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := putOverride(ctx, adminClient, setting.context, openAIModel.ID.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := setting.dbGet(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, openAIModel.ID.String(), raw, "expected stored override for %s", settingPath(setting.context))
|
||||
|
||||
resp, err := getOverride(ctx, adminClient, setting.context)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, setting.context, resp.context)
|
||||
require.Equal(t, openAIModel.ID.String(), resp.modelConfigID)
|
||||
require.False(t, resp.isMalformed)
|
||||
|
||||
err = putOverride(ctx, adminClient, setting.context, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err = setting.dbGet(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, raw, "expected cleared override for %s", settingPath(setting.context))
|
||||
|
||||
resp, err = getOverride(ctx, adminClient, setting.context)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, setting.context, resp.context)
|
||||
require.Empty(t, resp.modelConfigID)
|
||||
require.False(t, resp.isMalformed)
|
||||
})
|
||||
|
||||
t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
require.NoError(t, setting.dbUpsert(ctx, db, "not-a-uuid"))
|
||||
|
||||
resp, err := getOverride(ctx, adminClient, setting.context)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, setting.context, resp.context)
|
||||
require.Empty(t, resp.modelConfigID)
|
||||
require.True(t, resp.isMalformed)
|
||||
|
||||
err = putOverride(ctx, adminClient, setting.context, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := setting.dbGet(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, raw, "expected malformed override to be cleared for %s", settingPath(setting.context))
|
||||
|
||||
resp, err = getOverride(ctx, adminClient, setting.context)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, setting.context, resp.context)
|
||||
require.Empty(t, resp.modelConfigID)
|
||||
require.False(t, resp.isMalformed)
|
||||
})
|
||||
|
||||
t.Run("InvalidUUIDReturns400", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := putOverride(ctx, adminClient, setting.context, "not-a-uuid")
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
|
||||
require.Equal(t, "Value \"not-a-uuid\" is not a valid UUID.", sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("DisabledModelReturns400", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := putOverride(ctx, adminClient, setting.context, disabledModel.ID.String())
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("UnknownModelReturns400", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
unknownModelID := uuid.New()
|
||||
|
||||
err := putOverride(ctx, adminClient, setting.context, unknownModelID.String())
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("NonAdminGETReturns404", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, err := getOverride(ctx, memberClient, setting.context)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("NonAdminPUTReturns403", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := putOverride(ctx, memberClient, setting.context, defaultModel.ID.String())
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
resp, err := adminClient.GetChatExploreModelOverride(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.ModelConfigID)
|
||||
require.Equal(t, overrideModel.ID, *resp.ModelConfigID)
|
||||
require.False(t, resp.HasMalformedOverride)
|
||||
|
||||
err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = adminClient.GetChatExploreModelOverride(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.ModelConfigID)
|
||||
require.False(t, resp.HasMalformedOverride)
|
||||
})
|
||||
|
||||
t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) {
|
||||
t.Run("UnknownContextReturns400", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
require.NoError(t, db.UpsertChatExploreModelOverride(
|
||||
dbauthz.AsSystemRestricted(ctx),
|
||||
"not-a-uuid",
|
||||
))
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
|
||||
|
||||
resp, err := adminClient.GetChatExploreModelOverride(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.ModelConfigID)
|
||||
require.True(t, resp.HasMalformedOverride)
|
||||
|
||||
err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = adminClient.GetChatExploreModelOverride(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, resp.ModelConfigID)
|
||||
require.False(t, resp.HasMalformedOverride)
|
||||
})
|
||||
|
||||
t.Run("DisabledModelReturns400", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
disabledModel := createAdditionalModel(t, "gpt-4.1-disabled", false)
|
||||
|
||||
err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
|
||||
ModelConfigID: &disabledModel.ID,
|
||||
})
|
||||
_, err := getOverride(ctx, adminClient, unknownContext)
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
|
||||
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
|
||||
require.Equal(
|
||||
t,
|
||||
`Expected one of general, explore. Got "not-a-context".`,
|
||||
sdkErr.Detail,
|
||||
)
|
||||
|
||||
err = putOverride(ctx, adminClient, unknownContext, "")
|
||||
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message)
|
||||
require.Equal(
|
||||
t,
|
||||
`Expected one of general, explore. Got "not-a-context".`,
|
||||
sdkErr.Detail,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("UnknownModelReturns400", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
unknownModelID := uuid.New()
|
||||
|
||||
err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
|
||||
ModelConfigID: &unknownModelID,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid model_config_id.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("NonAdminGETReturns404", func(t *testing.T) {
|
||||
t.Run("NonAdminUnknownContextUsesAuthResponse", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, err := memberClient.GetChatExploreModelOverride(ctx)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context")
|
||||
|
||||
_, err := getOverride(ctx, memberClient, unknownContext)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("NonAdminPUTReturns403", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := memberClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{
|
||||
ModelConfigID: &defaultModel.ID,
|
||||
})
|
||||
err = putOverride(ctx, memberClient, unknownContext, "")
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
|
||||
+187
-42
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -27,6 +28,18 @@ import (
|
||||
|
||||
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
|
||||
|
||||
var errInvalidModelOverrideMetadata = xerrors.New("invalid model override metadata")
|
||||
|
||||
type modelOverrideConfigResolver func(
|
||||
context.Context,
|
||||
uuid.UUID,
|
||||
) (database.ChatModelConfig, string, error)
|
||||
|
||||
type modelOverrideProviderKeysResolver func(
|
||||
context.Context,
|
||||
uuid.UUID,
|
||||
) (chatprovider.ProviderAPIKeys, error)
|
||||
|
||||
const (
|
||||
subagentAwaitPollInterval = 200 * time.Millisecond
|
||||
subagentAwaitFallbackPoll = 5 * time.Second
|
||||
@@ -90,66 +103,199 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (p *Server) resolveExploreSubagentModelConfigID(
|
||||
func subagentModelOverrideLogLabel(
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) string {
|
||||
switch overrideContext {
|
||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
||||
return "general delegated child"
|
||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
||||
return "explore"
|
||||
default:
|
||||
return string(overrideContext)
|
||||
}
|
||||
}
|
||||
|
||||
func readSubagentModelOverride(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
fallback uuid.UUID,
|
||||
) (uuid.UUID, error) {
|
||||
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
raw, err := p.db.GetChatExploreModelOverride(chatdCtx)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("get Explore model override: %w", err)
|
||||
}
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
configuredModelConfigID, err := uuid.Parse(trimmed)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx,
|
||||
"invalid Explore model override, falling back to current turn model",
|
||||
slog.F("raw_model_config_id", trimmed),
|
||||
slog.Error(err),
|
||||
db database.Store,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (string, error) {
|
||||
switch overrideContext {
|
||||
case codersdk.ChatAgentModelOverrideContextGeneral:
|
||||
return db.GetChatGeneralModelOverride(ctx)
|
||||
case codersdk.ChatAgentModelOverrideContextExplore:
|
||||
return db.GetChatExploreModelOverride(ctx)
|
||||
default:
|
||||
return "", xerrors.Errorf(
|
||||
"unknown subagent model override context %q",
|
||||
overrideContext,
|
||||
)
|
||||
return fallback, nil
|
||||
}
|
||||
modelConfig, err := p.db.GetEnabledChatModelConfigByID(
|
||||
chatdCtx,
|
||||
configuredModelConfigID,
|
||||
)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
p.logger.Warn(ctx,
|
||||
"explore model override is unavailable, falling back to current turn model",
|
||||
slog.F("model_config_id", configuredModelConfigID),
|
||||
)
|
||||
return fallback, nil
|
||||
}
|
||||
return uuid.Nil, xerrors.Errorf("get enabled chat model config by id: %w", err)
|
||||
}
|
||||
|
||||
func validateModelConfigAndResolveProvider(
|
||||
modelConfig database.ChatModelConfig,
|
||||
) (database.ChatModelConfig, string, error) {
|
||||
if !modelConfig.Enabled {
|
||||
return database.ChatModelConfig{}, "", sql.ErrNoRows
|
||||
}
|
||||
providerName, _, err := chatprovider.ResolveModelWithProviderHint(
|
||||
modelConfig.Model,
|
||||
modelConfig.Provider,
|
||||
)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("resolve Explore model provider: %w", err)
|
||||
return database.ChatModelConfig{}, "", xerrors.Errorf(
|
||||
"%w: %v",
|
||||
errInvalidModelOverrideMetadata,
|
||||
err,
|
||||
)
|
||||
}
|
||||
providerKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID)
|
||||
return modelConfig, providerName, nil
|
||||
}
|
||||
|
||||
func enabledProviderContainsName(
|
||||
providers []database.ChatProvider,
|
||||
providerName string,
|
||||
) bool {
|
||||
normalizedProviderName := chatprovider.NormalizeProvider(providerName)
|
||||
for _, provider := range providers {
|
||||
if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Server) resolveConfiguredModelOverride(
|
||||
ctx context.Context,
|
||||
overrideContext string,
|
||||
raw string,
|
||||
ownerID uuid.UUID,
|
||||
resolveModelConfig modelOverrideConfigResolver,
|
||||
resolveProviderKeys modelOverrideProviderKeysResolver,
|
||||
) (database.ChatModelConfig, bool, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return database.ChatModelConfig{}, false, nil
|
||||
}
|
||||
configuredModelConfigID, err := uuid.Parse(trimmed)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
p.logger.Info(ctx,
|
||||
"invalid model override, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
slog.F("raw_model_config_id", trimmed),
|
||||
slog.Error(err),
|
||||
)
|
||||
return database.ChatModelConfig{}, false, nil
|
||||
}
|
||||
if providerKeys.APIKey(providerName) == "" {
|
||||
p.logger.Warn(ctx,
|
||||
"explore model override credentials are unavailable, falling back to current turn model",
|
||||
modelConfig, providerName, err := resolveModelConfig(
|
||||
ctx,
|
||||
configuredModelConfigID,
|
||||
)
|
||||
if err != nil {
|
||||
switch {
|
||||
case xerrors.Is(err, sql.ErrNoRows):
|
||||
p.logger.Info(ctx,
|
||||
"model override is unavailable, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
slog.F("model_config_id", configuredModelConfigID),
|
||||
)
|
||||
case errors.Is(err, errInvalidModelOverrideMetadata):
|
||||
p.logger.Info(ctx,
|
||||
"model override metadata is invalid, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
slog.F("model_config_id", configuredModelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
default:
|
||||
p.logger.Warn(ctx,
|
||||
"failed to resolve model override, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
slog.F("model_config_id", configuredModelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
return database.ChatModelConfig{}, false, nil
|
||||
}
|
||||
providerKeys, err := resolveProviderKeys(ctx, ownerID)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, false, xerrors.Errorf(
|
||||
"resolve provider API keys: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
if providerKeys.APIKey(providerName) == "" &&
|
||||
!(chatprovider.ProviderAllowsAmbientCredentials(providerName) &&
|
||||
providerKeys.HasProvider(providerName)) {
|
||||
p.logger.Info(ctx,
|
||||
"model override credentials are unavailable, ignoring",
|
||||
slog.F("override_context", overrideContext),
|
||||
slog.F("model_config_id", configuredModelConfigID),
|
||||
slog.F("provider", providerName),
|
||||
)
|
||||
return fallback, nil
|
||||
return database.ChatModelConfig{}, false, nil
|
||||
}
|
||||
return modelConfig, true, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveSubagentModelConfigID(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
overrideContext codersdk.ChatAgentModelOverrideContext,
|
||||
) (uuid.UUID, error) {
|
||||
//nolint:gocritic // Chatd needs its scoped deployment-config read access here.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
raw, err := readSubagentModelOverride(chatdCtx, p.db, overrideContext)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf(
|
||||
"get %s model override: %w",
|
||||
subagentModelOverrideLogLabel(overrideContext),
|
||||
err,
|
||||
)
|
||||
}
|
||||
modelConfig, ok, err := p.resolveConfiguredModelOverride(
|
||||
ctx,
|
||||
string(overrideContext),
|
||||
raw,
|
||||
ownerID,
|
||||
p.resolveModelConfigAndNormalizedProvider,
|
||||
p.resolveUserProviderAPIKeys,
|
||||
)
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
if !ok {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
return modelConfig.ID, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveModelConfigAndNormalizedProvider(
|
||||
ctx context.Context,
|
||||
modelConfigID uuid.UUID,
|
||||
) (database.ChatModelConfig, string, error) {
|
||||
if modelConfigID == uuid.Nil {
|
||||
return database.ChatModelConfig{}, "", sql.ErrNoRows
|
||||
}
|
||||
modelConfig, err := p.configCache.ModelConfigByID(ctx, modelConfigID)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, "", err
|
||||
}
|
||||
modelConfig, providerName, err := validateModelConfigAndResolveProvider(modelConfig)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, "", err
|
||||
}
|
||||
enabledProviders, err := p.configCache.EnabledProviders(ctx)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, "", err
|
||||
}
|
||||
if !enabledProviderContainsName(enabledProviders, providerName) {
|
||||
return database.ChatModelConfig{}, "", sql.ErrNoRows
|
||||
}
|
||||
return modelConfig, providerName, nil
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(
|
||||
ctx context.Context,
|
||||
currentChat func() database.Chat,
|
||||
@@ -444,7 +590,6 @@ func (p *Server) loadSubagentSpawnParentChat(
|
||||
if err := validateSubagentSpawnParent(parent); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
reloadedParent, err := p.db.GetChatByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to load parent chat for spawn_agent",
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -43,22 +44,37 @@ func allSubagentDefinitions() []subagentDefinition {
|
||||
{
|
||||
id: subagentTypeGeneral,
|
||||
description: "delegated work that may inspect or modify workspace files",
|
||||
buildOptions: func(_ context.Context, _ *Server, _ database.Chat, _ database.Chat, _ uuid.UUID, _ string) (childSubagentChatOptions, error) {
|
||||
return childSubagentChatOptions{}, nil
|
||||
buildOptions: func(ctx context.Context, p *Server, parent database.Chat, _ database.Chat, _ uuid.UUID, _ string) (childSubagentChatOptions, error) {
|
||||
modelConfigID, err := p.resolveSubagentModelConfigID(
|
||||
ctx,
|
||||
parent.OwnerID,
|
||||
codersdk.ChatAgentModelOverrideContextGeneral,
|
||||
)
|
||||
if err != nil {
|
||||
return childSubagentChatOptions{}, err
|
||||
}
|
||||
options := childSubagentChatOptions{}
|
||||
if modelConfigID != uuid.Nil {
|
||||
options.modelConfigIDOverride = &modelConfigID
|
||||
}
|
||||
return options, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
id: subagentTypeExplore,
|
||||
description: "read-only discovery, code tracing, and system understanding",
|
||||
buildOptions: func(ctx context.Context, p *Server, _ database.Chat, turnParent database.Chat, currentModelConfigID uuid.UUID, _ string) (childSubagentChatOptions, error) {
|
||||
modelConfigID, err := p.resolveExploreSubagentModelConfigID(
|
||||
modelConfigID, err := p.resolveSubagentModelConfigID(
|
||||
ctx,
|
||||
turnParent.OwnerID,
|
||||
currentModelConfigID,
|
||||
codersdk.ChatAgentModelOverrideContextExplore,
|
||||
)
|
||||
if err != nil {
|
||||
return childSubagentChatOptions{}, err
|
||||
}
|
||||
if modelConfigID == uuid.Nil {
|
||||
modelConfigID = currentModelConfigID
|
||||
}
|
||||
inheritedMCPServerIDs, err := p.resolveExploreToolSnapshot(
|
||||
ctx,
|
||||
turnParent,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -71,7 +73,14 @@ func newInternalTestServer(
|
||||
ps pubsub.Pubsub,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) *Server {
|
||||
return newInternalTestServerWithClock(t, db, ps, keys, nil)
|
||||
return newInternalTestServerWithLoggerAndClock(
|
||||
t,
|
||||
db,
|
||||
ps,
|
||||
keys,
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func newInternalTestServerWithClock(
|
||||
@@ -80,10 +89,37 @@ func newInternalTestServerWithClock(
|
||||
ps pubsub.Pubsub,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
clk quartz.Clock,
|
||||
) *Server {
|
||||
return newInternalTestServerWithLoggerAndClock(
|
||||
t,
|
||||
db,
|
||||
ps,
|
||||
keys,
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
clk,
|
||||
)
|
||||
}
|
||||
|
||||
func newInternalTestServerWithLogger(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps pubsub.Pubsub,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
logger slog.Logger,
|
||||
) *Server {
|
||||
return newInternalTestServerWithLoggerAndClock(t, db, ps, keys, logger, nil)
|
||||
}
|
||||
|
||||
func newInternalTestServerWithLoggerAndClock(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps pubsub.Pubsub,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
logger slog.Logger,
|
||||
clk quartz.Clock,
|
||||
) *Server {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := New(Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
@@ -101,6 +137,35 @@ func newInternalTestServerWithClock(
|
||||
return server
|
||||
}
|
||||
|
||||
type subagentTestLogSink struct {
|
||||
mu sync.Mutex
|
||||
entries []slog.SinkEntry
|
||||
}
|
||||
|
||||
func (s *subagentTestLogSink) LogEntry(_ context.Context, entry slog.SinkEntry) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.entries = append(s.entries, entry)
|
||||
}
|
||||
|
||||
func (*subagentTestLogSink) Sync() {}
|
||||
|
||||
func (s *subagentTestLogSink) entriesAtLevelWithMessage(
|
||||
level slog.Level,
|
||||
message string,
|
||||
) []slog.SinkEntry {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := make([]slog.SinkEntry, 0, len(s.entries))
|
||||
for _, entry := range s.entries {
|
||||
if entry.Level == level && entry.Message == message {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
// seedInternalChatDeps inserts an OpenAI provider and model config
|
||||
// into the database and returns the created user, organization,
|
||||
// and model. This deliberately does NOT create an Anthropic
|
||||
@@ -218,6 +283,54 @@ func insertInternalChatModelConfig(
|
||||
userID uuid.UUID,
|
||||
model string,
|
||||
enabled bool,
|
||||
) database.ChatModelConfig {
|
||||
return insertInternalChatModelConfigForProvider(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
userID,
|
||||
"openai",
|
||||
model,
|
||||
enabled,
|
||||
)
|
||||
}
|
||||
|
||||
func insertInternalChatProvider(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
provider string,
|
||||
apiKey string,
|
||||
centralAPIKeyEnabled bool,
|
||||
allowUserAPIKey bool,
|
||||
allowCentralAPIKeyFallback bool,
|
||||
) database.ChatProvider {
|
||||
t.Helper()
|
||||
|
||||
providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: provider,
|
||||
DisplayName: provider,
|
||||
APIKey: apiKey,
|
||||
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: centralAPIKeyEnabled,
|
||||
AllowUserApiKey: allowUserAPIKey,
|
||||
AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return providerConfig
|
||||
}
|
||||
|
||||
func insertInternalChatModelConfigForProvider(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
provider string,
|
||||
model string,
|
||||
enabled bool,
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
return insertInternalChatModelConfigWithOptions(
|
||||
@@ -225,6 +338,7 @@ func insertInternalChatModelConfig(
|
||||
t,
|
||||
db,
|
||||
userID,
|
||||
provider,
|
||||
model,
|
||||
enabled,
|
||||
json.RawMessage(`{}`),
|
||||
@@ -236,6 +350,7 @@ func insertInternalChatModelConfigWithOptions(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
provider string,
|
||||
model string,
|
||||
enabled bool,
|
||||
options json.RawMessage,
|
||||
@@ -243,7 +358,7 @@ func insertInternalChatModelConfigWithOptions(
|
||||
t.Helper()
|
||||
|
||||
modelConfig, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
||||
@@ -574,6 +689,213 @@ func TestSpawnAgent_GeneralInheritsParentModelWhenOmitted(t *testing.T) {
|
||||
require.Equal(t, parentChat.LastModelConfigID, childChat.LastModelConfigID)
|
||||
}
|
||||
|
||||
func TestSpawnAgent_GeneralUsesConfiguredModelOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, org, model := seedInternalChatDeps(ctx, t, db)
|
||||
overrideModel := insertInternalChatModelConfig(
|
||||
ctx, t, db, user.ID, "general-override-"+uuid.NewString(), true,
|
||||
)
|
||||
require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String()))
|
||||
parentChat := createInternalParentChat(
|
||||
ctx, t, server, db, org.ID, user.ID, model.ID, "parent-general-override",
|
||||
)
|
||||
|
||||
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
|
||||
Type: subagentTypeGeneral,
|
||||
Prompt: "delegate general work",
|
||||
})
|
||||
childID := requireSpawnAgentChildChatID(t, resp)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, overrideModel.ID, childChat.LastModelConfigID)
|
||||
require.False(t, childChat.PlanMode.Valid)
|
||||
}
|
||||
|
||||
func TestSpawnAgent_GeneralOverrideLogsAndFallsBackWhenCredentialsUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
logSink := &subagentTestLogSink{}
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
|
||||
server := newInternalTestServerWithLogger(t, db, ps, chatprovider.ProviderAPIKeys{}, logger)
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, org, model := seedInternalChatDeps(ctx, t, db)
|
||||
insertInternalChatProvider(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
user.ID,
|
||||
"openai-compat",
|
||||
"",
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
overrideModel := insertInternalChatModelConfigForProvider(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
user.ID,
|
||||
"openai-compat",
|
||||
"gpt-4o-mini",
|
||||
true,
|
||||
)
|
||||
require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String()))
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-general-credentials-fallback",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("delegate work"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
|
||||
Type: subagentTypeGeneral,
|
||||
Prompt: "inspect provider credentials",
|
||||
})
|
||||
childID := requireSpawnAgentChildChatID(t, resp)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, model.ID, childChat.LastModelConfigID)
|
||||
require.False(t, childChat.PlanMode.Valid)
|
||||
require.Len(t, logSink.entriesAtLevelWithMessage(
|
||||
slog.LevelInfo,
|
||||
"model override credentials are unavailable, ignoring",
|
||||
), 1)
|
||||
}
|
||||
|
||||
func TestSpawnAgent_GeneralOverrideLogsAndFallsBackWhenProviderDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
logSink := &subagentTestLogSink{}
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
|
||||
server := newInternalTestServerWithLogger(
|
||||
t,
|
||||
db,
|
||||
ps,
|
||||
chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
"openai-compat": "fallback-key",
|
||||
},
|
||||
},
|
||||
logger,
|
||||
)
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, org, model := seedInternalChatDeps(ctx, t, db)
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai-compat",
|
||||
DisplayName: "openai-compat",
|
||||
APIKey: "",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: false,
|
||||
CentralApiKeyEnabled: false,
|
||||
AllowUserApiKey: true,
|
||||
AllowCentralApiKeyFallback: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
overrideModel := insertInternalChatModelConfigForProvider(
|
||||
ctx,
|
||||
t,
|
||||
db,
|
||||
user.ID,
|
||||
"openai-compat",
|
||||
"gpt-4o-mini",
|
||||
true,
|
||||
)
|
||||
require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String()))
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-general-disabled-provider-fallback",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("delegate work"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{
|
||||
Type: subagentTypeGeneral,
|
||||
Prompt: "inspect disabled providers",
|
||||
})
|
||||
childID := requireSpawnAgentChildChatID(t, resp)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, model.ID, childChat.LastModelConfigID)
|
||||
require.False(t, childChat.PlanMode.Valid)
|
||||
require.Len(t, logSink.entriesAtLevelWithMessage(
|
||||
slog.LevelInfo,
|
||||
"model override is unavailable, ignoring",
|
||||
), 1)
|
||||
}
|
||||
|
||||
func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider(
|
||||
t *testing.T,
|
||||
) {
|
||||
t.Parallel()
|
||||
|
||||
logSink := &subagentTestLogSink{}
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
|
||||
server := &Server{logger: logger}
|
||||
ctx := chatdTestContext(t)
|
||||
ownerID := uuid.New()
|
||||
modelConfig := database.ChatModelConfig{
|
||||
ID: uuid.New(),
|
||||
Provider: "bedrock",
|
||||
Model: "anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
DisplayName: "Ambient Bedrock Override",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
resolvedModelConfig, ok, err := server.resolveConfiguredModelOverride(
|
||||
ctx,
|
||||
"plan",
|
||||
modelConfig.ID.String(),
|
||||
ownerID,
|
||||
func(
|
||||
_ context.Context,
|
||||
configuredModelConfigID uuid.UUID,
|
||||
) (database.ChatModelConfig, string, error) {
|
||||
require.Equal(t, modelConfig.ID, configuredModelConfigID)
|
||||
return modelConfig, "bedrock", nil
|
||||
},
|
||||
func(
|
||||
_ context.Context,
|
||||
resolvedOwnerID uuid.UUID,
|
||||
) (chatprovider.ProviderAPIKeys, error) {
|
||||
require.Equal(t, ownerID, resolvedOwnerID)
|
||||
return chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{"bedrock": ""},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, modelConfig, resolvedModelConfig)
|
||||
require.Empty(t, logSink.entriesAtLevelWithMessage(
|
||||
slog.LevelInfo,
|
||||
"model override credentials are unavailable, ignoring",
|
||||
))
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChat_OverrideWorksWhenParentHasNoModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1328,7 +1650,6 @@ func TestSubagentLifecycleToolsIncludePersistedSubagentTypeAcrossVariants(t *tes
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1450,7 +1771,6 @@ func TestSubagentLifecycleToolErrorsIncludePersistedSubagentType(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+57
-22
@@ -562,19 +562,46 @@ type UpdateChatPlanModeInstructionsRequest struct {
|
||||
PlanModeInstructions string `json:"plan_mode_instructions"`
|
||||
}
|
||||
|
||||
// ChatExploreModelOverrideResponse is the response body for the Explore
|
||||
// subagent model override configuration endpoint.
|
||||
type ChatExploreModelOverrideResponse struct {
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
// HasMalformedOverride reports whether the saved override is malformed and
|
||||
// is currently being treated as unset.
|
||||
HasMalformedOverride bool `json:"has_malformed_override"`
|
||||
// ChatAgentModelOverrideContext identifies which chat or subagent context
|
||||
// a deployment override applies to.
|
||||
type ChatAgentModelOverrideContext string
|
||||
|
||||
const (
|
||||
ChatAgentModelOverrideContextGeneral ChatAgentModelOverrideContext = "general"
|
||||
ChatAgentModelOverrideContextExplore ChatAgentModelOverrideContext = "explore"
|
||||
)
|
||||
|
||||
// Valid reports whether the override context is one of the supported values.
|
||||
func (c ChatAgentModelOverrideContext) Valid() bool {
|
||||
switch c {
|
||||
case ChatAgentModelOverrideContextGeneral,
|
||||
ChatAgentModelOverrideContextExplore:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateChatExploreModelOverrideRequest is the request body for updating the
|
||||
// Explore subagent model override configuration endpoint.
|
||||
type UpdateChatExploreModelOverrideRequest struct {
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
// AllChatAgentModelOverrideContexts returns all supported override contexts.
|
||||
func AllChatAgentModelOverrideContexts() []ChatAgentModelOverrideContext {
|
||||
return []ChatAgentModelOverrideContext{
|
||||
ChatAgentModelOverrideContextGeneral,
|
||||
ChatAgentModelOverrideContextExplore,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatAgentModelOverrideResponse is the response body for the chat agent
|
||||
// model override configuration endpoint.
|
||||
type ChatAgentModelOverrideResponse struct {
|
||||
Context ChatAgentModelOverrideContext `json:"context"`
|
||||
ModelConfigID string `json:"model_config_id"`
|
||||
IsMalformed bool `json:"is_malformed"`
|
||||
}
|
||||
|
||||
// UpdateChatAgentModelOverrideRequest is the request body for updating the
|
||||
// chat agent model override configuration endpoint.
|
||||
type UpdateChatAgentModelOverrideRequest struct {
|
||||
ModelConfigID string `json:"model_config_id"`
|
||||
}
|
||||
|
||||
// UserChatCustomPrompt is the request and response body for the
|
||||
@@ -2024,25 +2051,33 @@ func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatExploreModelOverride returns the deployment-wide Explore subagent
|
||||
// model override.
|
||||
func (c *ExperimentalClient) GetChatExploreModelOverride(ctx context.Context) (ChatExploreModelOverrideResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/explore-model-override", nil)
|
||||
// GetChatAgentModelOverride returns the deployment-wide chat agent model
|
||||
// override for the requested context.
|
||||
func (c *ExperimentalClient) GetChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext) (ChatAgentModelOverrideResponse, error) {
|
||||
path := fmt.Sprintf(
|
||||
"/api/experimental/chats/config/agent-model-override/%s",
|
||||
url.PathEscape(string(override)),
|
||||
)
|
||||
res, err := c.Request(ctx, http.MethodGet, path, nil)
|
||||
if err != nil {
|
||||
return ChatExploreModelOverrideResponse{}, err
|
||||
return ChatAgentModelOverrideResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatExploreModelOverrideResponse{}, ReadBodyAsError(res)
|
||||
return ChatAgentModelOverrideResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatExploreModelOverrideResponse
|
||||
var resp ChatAgentModelOverrideResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// UpdateChatExploreModelOverride updates the deployment-wide Explore subagent
|
||||
// model override.
|
||||
func (c *ExperimentalClient) UpdateChatExploreModelOverride(ctx context.Context, req UpdateChatExploreModelOverrideRequest) error {
|
||||
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/explore-model-override", req)
|
||||
// UpdateChatAgentModelOverride updates the deployment-wide chat agent model
|
||||
// override for the requested context.
|
||||
func (c *ExperimentalClient) UpdateChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext, req UpdateChatAgentModelOverrideRequest) error {
|
||||
path := fmt.Sprintf(
|
||||
"/api/experimental/chats/config/agent-model-override/%s",
|
||||
url.PathEscape(string(override)),
|
||||
)
|
||||
res, err := c.Request(ctx, http.MethodPut, path, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+13
-11
@@ -3257,20 +3257,22 @@ class ExperimentalApiMethods {
|
||||
);
|
||||
};
|
||||
|
||||
getChatExploreModelOverride =
|
||||
async (): Promise<TypesGen.ChatExploreModelOverrideResponse> => {
|
||||
const response =
|
||||
await this.axios.get<TypesGen.ChatExploreModelOverrideResponse>(
|
||||
"/api/experimental/chats/config/explore-model-override",
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
getChatAgentModelOverride = async (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
): Promise<TypesGen.ChatAgentModelOverrideResponse> => {
|
||||
const response =
|
||||
await this.axios.get<TypesGen.ChatAgentModelOverrideResponse>(
|
||||
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
updateChatExploreModelOverride = async (
|
||||
req: TypesGen.UpdateChatExploreModelOverrideRequest,
|
||||
updateChatAgentModelOverride = async (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
req: TypesGen.UpdateChatAgentModelOverrideRequest,
|
||||
): Promise<void> => {
|
||||
await this.axios.put(
|
||||
"/api/experimental/chats/config/explore-model-override",
|
||||
`/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`,
|
||||
req,
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1127,23 +1127,6 @@ export const updateChatPlanModeInstructions = (queryClient: QueryClient) => ({
|
||||
},
|
||||
});
|
||||
|
||||
const chatExploreModelOverrideKey = ["chat-explore-model-override"] as const;
|
||||
|
||||
export const chatExploreModelOverride = () => ({
|
||||
queryKey: chatExploreModelOverrideKey,
|
||||
queryFn: () => API.experimental.getChatExploreModelOverride(),
|
||||
});
|
||||
|
||||
export const updateChatExploreModelOverride = (queryClient: QueryClient) => ({
|
||||
mutationFn: (req: TypesGen.UpdateChatExploreModelOverrideRequest) =>
|
||||
API.experimental.updateChatExploreModelOverride(req),
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: chatExploreModelOverrideKey,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const chatDesktopEnabledKey = ["chat-desktop-enabled"] as const;
|
||||
|
||||
export const chatDesktopEnabled = () => ({
|
||||
|
||||
Generated
+28
-23
@@ -1272,6 +1272,25 @@ export interface Chat {
|
||||
readonly children: readonly Chat[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatAgentModelOverrideContext = "explore" | "general";
|
||||
|
||||
export const ChatAgentModelOverrideContexts: ChatAgentModelOverrideContext[] = [
|
||||
"explore",
|
||||
"general",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatAgentModelOverrideResponse is the response body for the chat agent
|
||||
* model override configuration endpoint.
|
||||
*/
|
||||
export interface ChatAgentModelOverrideResponse {
|
||||
readonly context: ChatAgentModelOverrideContext;
|
||||
readonly model_config_id: string;
|
||||
readonly is_malformed: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatBusyBehavior = "interrupt" | "queue";
|
||||
|
||||
@@ -1590,20 +1609,6 @@ export interface ChatDiffStatus {
|
||||
readonly stale_at?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatExploreModelOverrideResponse is the response body for the Explore
|
||||
* subagent model override configuration endpoint.
|
||||
*/
|
||||
export interface ChatExploreModelOverrideResponse {
|
||||
readonly model_config_id?: string;
|
||||
/**
|
||||
* HasMalformedOverride reports whether the saved override is malformed and
|
||||
* is currently being treated as unset.
|
||||
*/
|
||||
readonly has_malformed_override: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatFileMetadata contains lightweight metadata about a file
|
||||
@@ -7657,6 +7662,15 @@ export interface UpdateAppearanceConfig {
|
||||
readonly announcement_banners: readonly BannerConfig[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatAgentModelOverrideRequest is the request body for updating the
|
||||
* chat agent model override configuration endpoint.
|
||||
*/
|
||||
export interface UpdateChatAgentModelOverrideRequest {
|
||||
readonly model_config_id: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatDebugLoggingAllowUsersRequest is the admin request to
|
||||
@@ -7674,15 +7688,6 @@ export interface UpdateChatDesktopEnabledRequest {
|
||||
readonly enable_desktop: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatExploreModelOverrideRequest is the request body for updating the
|
||||
* Explore subagent model override configuration endpoint.
|
||||
*/
|
||||
export interface UpdateChatExploreModelOverrideRequest {
|
||||
readonly model_config_id?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatModelConfigRequest updates a chat model config.
|
||||
|
||||
@@ -1,34 +1,83 @@
|
||||
import type { FC } from "react";
|
||||
import { useMutation, useQuery, useQueryClient } from "react-query";
|
||||
import {
|
||||
chatExploreModelOverride,
|
||||
chatModelConfigs,
|
||||
updateChatExploreModelOverride,
|
||||
} from "#/api/queries/chats";
|
||||
type QueryClient,
|
||||
useMutation,
|
||||
useQuery,
|
||||
useQueryClient,
|
||||
} from "react-query";
|
||||
import { API } from "#/api/api";
|
||||
import { chatModelConfigs } from "#/api/queries/chats";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { useAuthenticated } from "#/hooks/useAuthenticated";
|
||||
import { RequirePermission } from "#/modules/permissions/RequirePermission";
|
||||
import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView";
|
||||
|
||||
const generalOverrideContext: TypesGen.ChatAgentModelOverrideContext =
|
||||
"general";
|
||||
const exploreOverrideContext: TypesGen.ChatAgentModelOverrideContext =
|
||||
"explore";
|
||||
|
||||
const chatAgentModelOverrideKey = (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
) => ["chat-agent-model-override", context] as const;
|
||||
|
||||
const chatAgentModelOverrideQuery = (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
) => ({
|
||||
queryKey: chatAgentModelOverrideKey(context),
|
||||
queryFn: () => API.experimental.getChatAgentModelOverride(context),
|
||||
});
|
||||
|
||||
const updateChatAgentModelOverrideMutation = (
|
||||
queryClient: QueryClient,
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
) => ({
|
||||
mutationFn: (req: TypesGen.UpdateChatAgentModelOverrideRequest) =>
|
||||
API.experimental.updateChatAgentModelOverride(context, req),
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: chatAgentModelOverrideKey(context),
|
||||
exact: true,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const AgentSettingsAgentsPage: FC = () => {
|
||||
const { permissions } = useAuthenticated();
|
||||
const queryClient = useQueryClient();
|
||||
const canEditDeploymentConfig = permissions.editDeploymentConfig;
|
||||
|
||||
const generalModelOverrideQuery = useQuery({
|
||||
...chatAgentModelOverrideQuery(generalOverrideContext),
|
||||
enabled: canEditDeploymentConfig,
|
||||
});
|
||||
const exploreModelOverrideQuery = useQuery({
|
||||
...chatExploreModelOverride(),
|
||||
enabled: permissions.editDeploymentConfig,
|
||||
...chatAgentModelOverrideQuery(exploreOverrideContext),
|
||||
enabled: canEditDeploymentConfig,
|
||||
});
|
||||
const modelConfigsQuery = useQuery(chatModelConfigs());
|
||||
const saveGeneralModelOverrideMutation = useMutation(
|
||||
updateChatAgentModelOverrideMutation(queryClient, generalOverrideContext),
|
||||
);
|
||||
const saveExploreModelOverrideMutation = useMutation(
|
||||
updateChatExploreModelOverride(queryClient),
|
||||
updateChatAgentModelOverrideMutation(queryClient, exploreOverrideContext),
|
||||
);
|
||||
|
||||
return (
|
||||
<RequirePermission isFeatureVisible={permissions.editDeploymentConfig}>
|
||||
<RequirePermission isFeatureVisible={canEditDeploymentConfig}>
|
||||
<AgentSettingsAgentsPageView
|
||||
generalModelOverrideData={generalModelOverrideQuery.data}
|
||||
exploreModelOverrideData={exploreModelOverrideQuery.data}
|
||||
modelConfigsData={modelConfigsQuery.data}
|
||||
modelConfigsError={modelConfigsQuery.error}
|
||||
isLoadingModelConfigs={modelConfigsQuery.isLoading}
|
||||
onSaveGeneralModelOverride={saveGeneralModelOverrideMutation.mutate}
|
||||
isSavingGeneralModelOverride={
|
||||
saveGeneralModelOverrideMutation.isPending
|
||||
}
|
||||
isSaveGeneralModelOverrideError={
|
||||
saveGeneralModelOverrideMutation.isError
|
||||
}
|
||||
onSaveExploreModelOverride={saveExploreModelOverrideMutation.mutate}
|
||||
isSavingExploreModelOverride={
|
||||
saveExploreModelOverrideMutation.isPending
|
||||
|
||||
@@ -6,208 +6,294 @@ import {
|
||||
type AgentSettingsAgentsPageViewProps,
|
||||
} from "./AgentSettingsAgentsPageView";
|
||||
|
||||
const baseArgs: AgentSettingsAgentsPageViewProps = {
|
||||
exploreModelOverrideData: {
|
||||
has_malformed_override: false,
|
||||
},
|
||||
modelConfigsData: [],
|
||||
const OVERRIDE_MALFORMED_WARNING =
|
||||
"The saved override is malformed and is being treated as unset. Click Save to clear it.";
|
||||
const UNAVAILABLE_SAVED_MODEL_WARNING =
|
||||
"The saved model is no longer enabled and will be ignored until you choose a new override.";
|
||||
|
||||
const buildModelConfig = (
|
||||
overrides: Partial<TypesGen.ChatModelConfig>,
|
||||
): TypesGen.ChatModelConfig => ({
|
||||
id: "model-default",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1-mini",
|
||||
display_name: "GPT 4.1 Mini",
|
||||
enabled: true,
|
||||
is_default: false,
|
||||
context_limit: 1_000_000,
|
||||
compression_threshold: 70,
|
||||
created_at: "2026-03-12T12:00:00.000Z",
|
||||
updated_at: "2026-03-12T12:00:00.000Z",
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const buildOverrideData = (
|
||||
context: TypesGen.ChatAgentModelOverrideContext,
|
||||
overrides: Partial<TypesGen.ChatAgentModelOverrideResponse> = {},
|
||||
): TypesGen.ChatAgentModelOverrideResponse => ({
|
||||
context,
|
||||
model_config_id: "",
|
||||
is_malformed: false,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const generalModelConfig = buildModelConfig({
|
||||
id: "model-general-gpt-4.1-mini",
|
||||
display_name: "GPT 4.1 Mini",
|
||||
});
|
||||
|
||||
const claudeSonnetModelConfig = buildModelConfig({
|
||||
id: "model-claude-sonnet-4",
|
||||
provider: "anthropic",
|
||||
model: "claude-sonnet-4",
|
||||
display_name: "Claude Sonnet 4",
|
||||
context_limit: 200_000,
|
||||
});
|
||||
|
||||
const exploreFallbackModelConfig = buildModelConfig({
|
||||
id: "model-explore-blank-display",
|
||||
provider: "anthropic",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
display_name: "",
|
||||
context_limit: 200_000,
|
||||
});
|
||||
|
||||
const generalDisabledModelConfig = buildModelConfig({
|
||||
id: "model-general-disabled",
|
||||
model: "gpt-4.1-legacy",
|
||||
display_name: "GPT 4.1 Legacy",
|
||||
enabled: false,
|
||||
});
|
||||
|
||||
const exploreDisabledModelConfig = buildModelConfig({
|
||||
id: "model-explore-disabled",
|
||||
provider: "anthropic",
|
||||
model: "claude-haiku-legacy",
|
||||
display_name: "Claude Haiku Legacy",
|
||||
enabled: false,
|
||||
context_limit: 200_000,
|
||||
});
|
||||
|
||||
const allModelConfigs: TypesGen.ChatModelConfig[] = [
|
||||
generalModelConfig,
|
||||
claudeSonnetModelConfig,
|
||||
exploreFallbackModelConfig,
|
||||
generalDisabledModelConfig,
|
||||
exploreDisabledModelConfig,
|
||||
];
|
||||
|
||||
const makeArgs = (
|
||||
overrides: Partial<AgentSettingsAgentsPageViewProps> = {},
|
||||
): AgentSettingsAgentsPageViewProps => ({
|
||||
generalModelOverrideData: buildOverrideData("general"),
|
||||
exploreModelOverrideData: buildOverrideData("explore"),
|
||||
modelConfigsData: allModelConfigs,
|
||||
modelConfigsError: undefined,
|
||||
isLoadingModelConfigs: false,
|
||||
onSaveGeneralModelOverride: fn(),
|
||||
isSavingGeneralModelOverride: false,
|
||||
isSaveGeneralModelOverrideError: false,
|
||||
onSaveExploreModelOverride: fn(),
|
||||
isSavingExploreModelOverride: false,
|
||||
isSaveExploreModelOverrideError: false,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const getSection = async (
|
||||
canvasElement: HTMLElement,
|
||||
headingName: string,
|
||||
): Promise<HTMLElement> => {
|
||||
const canvas = within(canvasElement);
|
||||
const heading = await canvas.findByRole("heading", { name: headingName });
|
||||
const section = heading.closest("section");
|
||||
if (!(section instanceof HTMLElement)) {
|
||||
throw new Error(
|
||||
`Expected ${headingName} heading to live inside a section.`,
|
||||
);
|
||||
}
|
||||
return section;
|
||||
};
|
||||
|
||||
const selectModelInSection = async (
|
||||
section: HTMLElement,
|
||||
canvasElement: HTMLElement,
|
||||
currentSelectionName: string | RegExp,
|
||||
optionName: string,
|
||||
) => {
|
||||
const trigger = within(section).getByRole("combobox", {
|
||||
name: currentSelectionName,
|
||||
});
|
||||
await userEvent.click(trigger);
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(await body.findByRole("option", { name: optionName }));
|
||||
};
|
||||
|
||||
const meta = {
|
||||
title: "pages/AgentsPage/AgentSettingsAgentsPageView",
|
||||
component: AgentSettingsAgentsPageView,
|
||||
args: baseArgs,
|
||||
args: makeArgs(),
|
||||
} satisfies Meta<typeof AgentSettingsAgentsPageView>;
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof AgentSettingsAgentsPageView>;
|
||||
|
||||
export const ExploreModelOverrideSetting: Story = {
|
||||
args: {
|
||||
exploreModelOverrideData: {
|
||||
model_config_id: "model-explore-1",
|
||||
has_malformed_override: false,
|
||||
},
|
||||
modelConfigsData: [
|
||||
{
|
||||
id: "model-explore-1",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1-mini",
|
||||
display_name: "GPT 4.1 Mini",
|
||||
enabled: true,
|
||||
is_default: false,
|
||||
context_limit: 1_000_000,
|
||||
compression_threshold: 70,
|
||||
created_at: "2026-03-12T12:00:00.000Z",
|
||||
updated_at: "2026-03-12T12:00:00.000Z",
|
||||
},
|
||||
{
|
||||
id: "model-explore-2",
|
||||
provider: "anthropic",
|
||||
model: "claude-sonnet-4",
|
||||
display_name: "Claude Sonnet 4",
|
||||
enabled: true,
|
||||
is_default: false,
|
||||
context_limit: 200_000,
|
||||
compression_threshold: 70,
|
||||
created_at: "2026-03-12T12:00:00.000Z",
|
||||
updated_at: "2026-03-12T12:00:00.000Z",
|
||||
},
|
||||
] as TypesGen.ChatModelConfig[],
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const canvas = within(canvasElement);
|
||||
await canvas.findByText("Agents");
|
||||
await canvas.findByText("Explore subagent model");
|
||||
const trigger = canvas.getByRole("combobox", {
|
||||
name: /gpt 4.1 mini/i,
|
||||
});
|
||||
await userEvent.click(trigger);
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(
|
||||
await body.findByRole("option", { name: "Claude Sonnet 4" }),
|
||||
);
|
||||
const form = trigger.closest("form");
|
||||
if (!(form instanceof HTMLFormElement)) {
|
||||
throw new Error("Expected Explore model selector to live inside a form.");
|
||||
}
|
||||
const saveButton = within(form).getByRole("button", { name: "Save" });
|
||||
await waitFor(() => {
|
||||
expect(saveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(saveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
|
||||
{ model_config_id: "model-explore-2" },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const ExploreModelOverrideAllowsExplicitClear: Story = {
|
||||
args: {
|
||||
exploreModelOverrideData: {
|
||||
model_config_id: "model-explore-clear",
|
||||
has_malformed_override: false,
|
||||
},
|
||||
modelConfigsData: [
|
||||
{
|
||||
id: "model-explore-clear",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1-mini",
|
||||
display_name: "GPT 4.1 Mini",
|
||||
enabled: true,
|
||||
is_default: false,
|
||||
context_limit: 1_000_000,
|
||||
compression_threshold: 70,
|
||||
created_at: "2026-03-12T12:00:00.000Z",
|
||||
updated_at: "2026-03-12T12:00:00.000Z",
|
||||
},
|
||||
] as TypesGen.ChatModelConfig[],
|
||||
onSaveExploreModelOverride: fn(),
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const clearButton = await canvas.findByRole("button", { name: "Clear" });
|
||||
const form = clearButton.closest("form");
|
||||
if (!(form instanceof HTMLFormElement)) {
|
||||
throw new Error(
|
||||
"Expected Explore model clear button to live inside a form.",
|
||||
);
|
||||
}
|
||||
|
||||
const saveButton = within(form).getByRole("button", { name: "Save" });
|
||||
await userEvent.click(clearButton);
|
||||
expect(args.onSaveExploreModelOverride).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
expect(saveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(saveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
|
||||
{},
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const ExploreModelOverrideClearsMalformedSavedValue: Story = {
|
||||
args: {
|
||||
exploreModelOverrideData: {
|
||||
has_malformed_override: true,
|
||||
},
|
||||
modelConfigsData: [],
|
||||
onSaveExploreModelOverride: fn(),
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const canvas = within(canvasElement);
|
||||
await canvas.findByText(
|
||||
"The saved override is malformed and is being treated as unset. Click Save to clear it.",
|
||||
);
|
||||
const clearButton = await canvas.findByRole("button", { name: "Clear" });
|
||||
const form = clearButton.closest("form");
|
||||
if (!(form instanceof HTMLFormElement)) {
|
||||
throw new Error(
|
||||
"Expected Explore model clear button to live inside a form.",
|
||||
);
|
||||
}
|
||||
|
||||
const saveButton = within(form).getByRole("button", { name: "Save" });
|
||||
await waitFor(() => {
|
||||
expect(saveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(clearButton);
|
||||
expect(args.onSaveExploreModelOverride).not.toHaveBeenCalled();
|
||||
await userEvent.click(saveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
|
||||
{},
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const ExploreModelOverrideFallsBackToModelName: Story = {
|
||||
args: {
|
||||
exploreModelOverrideData: {
|
||||
model_config_id: "model-explore-empty-name",
|
||||
has_malformed_override: false,
|
||||
},
|
||||
modelConfigsData: [
|
||||
{
|
||||
id: "model-explore-empty-name",
|
||||
provider: "anthropic",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
display_name: "",
|
||||
enabled: true,
|
||||
is_default: false,
|
||||
context_limit: 200_000,
|
||||
compression_threshold: 70,
|
||||
created_at: "2026-03-12T12:00:00.000Z",
|
||||
updated_at: "2026-03-12T12:00:00.000Z",
|
||||
},
|
||||
] as TypesGen.ChatModelConfig[],
|
||||
},
|
||||
export const AllOverridesUnset: Story = {
|
||||
args: makeArgs(),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const trigger = await canvas.findByRole("combobox", {
|
||||
name: /claude-sonnet-4-20250514/i,
|
||||
});
|
||||
expect(trigger).toHaveTextContent("claude-sonnet-4-20250514");
|
||||
await userEvent.click(trigger);
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
expect(
|
||||
await body.findByRole("option", {
|
||||
name: "claude-sonnet-4-20250514",
|
||||
}),
|
||||
).toBeInTheDocument();
|
||||
await canvas.findByText("Agents");
|
||||
|
||||
const headings = await canvas.findAllByRole("heading", { level: 3 });
|
||||
expect(headings.map((heading) => heading.textContent?.trim())).toEqual([
|
||||
"General model",
|
||||
"Explore subagent model",
|
||||
]);
|
||||
|
||||
for (const headingName of ["General model", "Explore subagent model"]) {
|
||||
const section = await getSection(canvasElement, headingName);
|
||||
expect(
|
||||
within(section).getByRole("combobox", { name: "Use chat default" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
within(section).getByRole("button", { name: "Save" }),
|
||||
).toBeDisabled();
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
export const EachOverrideSetToEnabledModel: Story = {
|
||||
args: makeArgs({
|
||||
generalModelOverrideData: buildOverrideData("general", {
|
||||
model_config_id: generalModelConfig.id,
|
||||
}),
|
||||
exploreModelOverrideData: buildOverrideData("explore", {
|
||||
model_config_id: exploreFallbackModelConfig.id,
|
||||
}),
|
||||
}),
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const generalSection = await getSection(canvasElement, "General model");
|
||||
const exploreSection = await getSection(
|
||||
canvasElement,
|
||||
"Explore subagent model",
|
||||
);
|
||||
|
||||
expect(
|
||||
within(exploreSection).getByRole("combobox", {
|
||||
name: /claude-sonnet-4-20250514/i,
|
||||
}),
|
||||
).toHaveTextContent("claude-sonnet-4-20250514");
|
||||
|
||||
await selectModelInSection(
|
||||
generalSection,
|
||||
canvasElement,
|
||||
/gpt 4\.1 mini/i,
|
||||
"Claude Sonnet 4",
|
||||
);
|
||||
const generalSaveButton = within(generalSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(generalSaveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(generalSaveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveGeneralModelOverride).toHaveBeenCalledWith(
|
||||
{ model_config_id: claudeSonnetModelConfig.id },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
const exploreClearButton = within(exploreSection).getByRole("button", {
|
||||
name: "Clear",
|
||||
});
|
||||
await userEvent.click(exploreClearButton);
|
||||
const exploreSaveButton = within(exploreSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(exploreSaveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(exploreSaveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
|
||||
{ model_config_id: "" },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const MalformedOverridesRemainClearableAndSaveable: Story = {
|
||||
args: makeArgs({
|
||||
generalModelOverrideData: buildOverrideData("general", {
|
||||
is_malformed: true,
|
||||
}),
|
||||
exploreModelOverrideData: buildOverrideData("explore", {
|
||||
is_malformed: true,
|
||||
}),
|
||||
}),
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const generalSection = await getSection(canvasElement, "General model");
|
||||
const exploreSection = await getSection(
|
||||
canvasElement,
|
||||
"Explore subagent model",
|
||||
);
|
||||
|
||||
for (const section of [generalSection, exploreSection]) {
|
||||
await within(section).findByText(OVERRIDE_MALFORMED_WARNING);
|
||||
}
|
||||
|
||||
const generalSaveButton = within(generalSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(generalSaveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(generalSaveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveGeneralModelOverride).toHaveBeenCalledWith(
|
||||
{ model_config_id: "" },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
const exploreSaveButton = within(exploreSection).getByRole("button", {
|
||||
name: "Save",
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(exploreSaveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(exploreSaveButton);
|
||||
await waitFor(() => {
|
||||
expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith(
|
||||
{ model_config_id: "" },
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const UnavailableSavedModels: Story = {
|
||||
args: makeArgs({
|
||||
generalModelOverrideData: buildOverrideData("general", {
|
||||
model_config_id: generalDisabledModelConfig.id,
|
||||
}),
|
||||
exploreModelOverrideData: buildOverrideData("explore", {
|
||||
model_config_id: exploreDisabledModelConfig.id,
|
||||
}),
|
||||
}),
|
||||
play: async ({ canvasElement }) => {
|
||||
const generalSection = await getSection(canvasElement, "General model");
|
||||
const exploreSection = await getSection(
|
||||
canvasElement,
|
||||
"Explore subagent model",
|
||||
);
|
||||
|
||||
for (const section of [generalSection, exploreSection]) {
|
||||
await within(section).findByText(UNAVAILABLE_SAVED_MODEL_WARNING);
|
||||
expect(
|
||||
within(section).getByRole("combobox", { name: "Unavailable model" }),
|
||||
).toBeInTheDocument();
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
import type { FC } from "react";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { ExploreModelOverrideSettings } from "./components/ExploreModelOverrideSettings";
|
||||
import { SectionHeader } from "./components/SectionHeader";
|
||||
import {
|
||||
type MutationCallbacks,
|
||||
SubagentModelOverrideSettings,
|
||||
} from "./components/SubagentModelOverrideSettings";
|
||||
|
||||
interface MutationCallbacks {
|
||||
onSuccess?: () => void;
|
||||
onError?: () => void;
|
||||
}
|
||||
type SaveModelOverride = (
|
||||
req: TypesGen.UpdateChatAgentModelOverrideRequest,
|
||||
options?: MutationCallbacks,
|
||||
) => void;
|
||||
|
||||
export interface AgentSettingsAgentsPageViewProps {
|
||||
exploreModelOverrideData:
|
||||
| TypesGen.ChatExploreModelOverrideResponse
|
||||
| undefined;
|
||||
generalModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
|
||||
exploreModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse;
|
||||
modelConfigsData: TypesGen.ChatModelConfig[] | undefined;
|
||||
modelConfigsError: unknown;
|
||||
isLoadingModelConfigs: boolean;
|
||||
onSaveExploreModelOverride: (
|
||||
req: TypesGen.UpdateChatExploreModelOverrideRequest,
|
||||
options?: MutationCallbacks,
|
||||
) => void;
|
||||
onSaveGeneralModelOverride?: SaveModelOverride;
|
||||
isSavingGeneralModelOverride?: boolean;
|
||||
isSaveGeneralModelOverrideError?: boolean;
|
||||
onSaveExploreModelOverride: SaveModelOverride;
|
||||
isSavingExploreModelOverride: boolean;
|
||||
isSaveExploreModelOverrideError: boolean;
|
||||
}
|
||||
@@ -26,37 +28,84 @@ export interface AgentSettingsAgentsPageViewProps {
|
||||
export const AgentSettingsAgentsPageView: FC<
|
||||
AgentSettingsAgentsPageViewProps
|
||||
> = ({
|
||||
generalModelOverrideData,
|
||||
exploreModelOverrideData,
|
||||
modelConfigsData,
|
||||
modelConfigsError,
|
||||
isLoadingModelConfigs,
|
||||
onSaveGeneralModelOverride,
|
||||
isSavingGeneralModelOverride = false,
|
||||
isSaveGeneralModelOverrideError = false,
|
||||
onSaveExploreModelOverride,
|
||||
isSavingExploreModelOverride,
|
||||
isSaveExploreModelOverrideError,
|
||||
}) => {
|
||||
const enabledModelConfigs = (modelConfigsData ?? []).filter(
|
||||
(modelConfig) => modelConfig.enabled,
|
||||
);
|
||||
const showGeneralModelSection =
|
||||
onSaveGeneralModelOverride !== undefined ||
|
||||
generalModelOverrideData !== undefined ||
|
||||
isSavingGeneralModelOverride ||
|
||||
isSaveGeneralModelOverrideError;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-8">
|
||||
<SectionHeader
|
||||
label="Agents"
|
||||
description="Configure defaults for delegated agents and other agent-specific capabilities."
|
||||
/>
|
||||
<div className="flex flex-col gap-3">
|
||||
{showGeneralModelSection && onSaveGeneralModelOverride && (
|
||||
<section aria-label="General model" className="flex flex-col gap-3">
|
||||
<SectionHeader
|
||||
label="General model"
|
||||
description="Deployment-wide model override for delegated subagents with write capabilities, such as editing files or running commands in the workspace."
|
||||
level="section"
|
||||
/>
|
||||
<SubagentModelOverrideSettings
|
||||
title="General model"
|
||||
description="Deployment-wide model override for delegated subagents with write capabilities, such as editing files or running commands in the workspace."
|
||||
modelOverrideData={generalModelOverrideData}
|
||||
enabledModelConfigs={enabledModelConfigs}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoading={isLoadingModelConfigs}
|
||||
onSaveModelOverride={onSaveGeneralModelOverride}
|
||||
isSaving={isSavingGeneralModelOverride}
|
||||
isSaveError={isSaveGeneralModelOverrideError}
|
||||
saveErrorMessage="Failed to save general model override."
|
||||
showHeader={false}
|
||||
/>
|
||||
</section>
|
||||
)}
|
||||
<section
|
||||
aria-label="Explore subagent model"
|
||||
className="flex flex-col gap-3"
|
||||
>
|
||||
<SectionHeader
|
||||
label="Explore subagent model"
|
||||
description="Optional deployment-wide model override for read-only Explore subagents."
|
||||
description="Deployment-wide model override for read-only Explore subagents."
|
||||
level="section"
|
||||
/>
|
||||
<ExploreModelOverrideSettings
|
||||
exploreModelOverrideData={exploreModelOverrideData}
|
||||
modelConfigs={modelConfigsData ?? []}
|
||||
<SubagentModelOverrideSettings
|
||||
title="Explore subagent model"
|
||||
description={
|
||||
<>
|
||||
Deployment-wide model override for read-only Explore subagents
|
||||
launched through the <code>spawn_agent</code> tool with a
|
||||
<code>type=explore</code> argument.
|
||||
</>
|
||||
}
|
||||
modelOverrideData={exploreModelOverrideData}
|
||||
enabledModelConfigs={enabledModelConfigs}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoadingModelConfigs={isLoadingModelConfigs}
|
||||
onSaveExploreModelOverride={onSaveExploreModelOverride}
|
||||
isSavingExploreModelOverride={isSavingExploreModelOverride}
|
||||
isSaveExploreModelOverrideError={isSaveExploreModelOverrideError}
|
||||
isLoading={isLoadingModelConfigs}
|
||||
onSaveModelOverride={onSaveExploreModelOverride}
|
||||
isSaving={isSavingExploreModelOverride}
|
||||
isSaveError={isSaveExploreModelOverrideError}
|
||||
saveErrorMessage="Failed to save Explore model override."
|
||||
showHeader={false}
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -156,7 +156,11 @@ const fixedNow = dayjs("2026-03-12T12:00:00");
|
||||
|
||||
const AgentsRouteElement = () => (
|
||||
<AgentSettingsAgentsPageView
|
||||
exploreModelOverrideData={{ has_malformed_override: false }}
|
||||
exploreModelOverrideData={{
|
||||
context: "explore",
|
||||
model_config_id: "",
|
||||
is_malformed: false,
|
||||
}}
|
||||
modelConfigsData={[]}
|
||||
modelConfigsError={undefined}
|
||||
isLoadingModelConfigs={false}
|
||||
@@ -609,10 +613,23 @@ export const WithErrorReasons: Story = {
|
||||
|
||||
const openSettingsView = async (canvasElement: HTMLElement) => {
|
||||
const canvas = within(canvasElement);
|
||||
const link = await waitFor(() =>
|
||||
canvas.getByRole("link", { name: "Settings" }),
|
||||
const settingsLink = canvas.queryByRole("link", { name: "Settings" });
|
||||
if (settingsLink) {
|
||||
await userEvent.click(settingsLink);
|
||||
return;
|
||||
}
|
||||
|
||||
const mobileMoreOptionsButton = canvas
|
||||
.getAllByRole("button", { name: "More options" })
|
||||
.find((button) => button.getAttribute("aria-haspopup") === "menu");
|
||||
if (!mobileMoreOptionsButton) {
|
||||
throw new Error("Expected a mobile More options menu button.");
|
||||
}
|
||||
await userEvent.click(mobileMoreOptionsButton);
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(
|
||||
await body.findByRole("menuitem", { name: "Settings" }),
|
||||
);
|
||||
await userEvent.click(link);
|
||||
};
|
||||
|
||||
export const OpensAnalyticsForAdmins: Story = {
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
import { useFormik } from "formik";
|
||||
import type { FC } from "react";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { Alert, AlertDescription } from "#/components/Alert/Alert";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import type { ModelSelectorOption } from "./ChatElements/ModelSelector";
|
||||
import { ModelSelector } from "./ChatElements/ModelSelector";
|
||||
|
||||
interface MutationCallbacks {
|
||||
onSuccess?: () => void;
|
||||
onError?: () => void;
|
||||
}
|
||||
|
||||
interface ExploreModelOverrideSettingsProps {
|
||||
exploreModelOverrideData:
|
||||
| TypesGen.ChatExploreModelOverrideResponse
|
||||
| undefined;
|
||||
modelConfigs: readonly TypesGen.ChatModelConfig[];
|
||||
modelConfigsError: unknown;
|
||||
isLoadingModelConfigs: boolean;
|
||||
onSaveExploreModelOverride: (
|
||||
req: TypesGen.UpdateChatExploreModelOverrideRequest,
|
||||
options?: MutationCallbacks,
|
||||
) => void;
|
||||
isSavingExploreModelOverride: boolean;
|
||||
isSaveExploreModelOverrideError: boolean;
|
||||
showHeader?: boolean;
|
||||
}
|
||||
|
||||
const toModelSelectorOption = (
|
||||
modelConfig: TypesGen.ChatModelConfig,
|
||||
): ModelSelectorOption => ({
|
||||
id: modelConfig.id,
|
||||
provider: modelConfig.provider,
|
||||
model: modelConfig.model,
|
||||
displayName: modelConfig.display_name.trim() || modelConfig.model,
|
||||
contextLimit: modelConfig.context_limit,
|
||||
});
|
||||
|
||||
export const ExploreModelOverrideSettings: FC<
|
||||
ExploreModelOverrideSettingsProps
|
||||
> = ({
|
||||
exploreModelOverrideData,
|
||||
modelConfigs,
|
||||
modelConfigsError,
|
||||
isLoadingModelConfigs,
|
||||
onSaveExploreModelOverride,
|
||||
isSavingExploreModelOverride,
|
||||
isSaveExploreModelOverrideError,
|
||||
showHeader = true,
|
||||
}) => {
|
||||
const hasLoadedExploreModelOverride = exploreModelOverrideData !== undefined;
|
||||
const enabledModelOptions = modelConfigs
|
||||
.filter((modelConfig) => modelConfig.enabled)
|
||||
.map(toModelSelectorOption);
|
||||
|
||||
const form = useFormik({
|
||||
enableReinitialize: true,
|
||||
initialValues: {
|
||||
model_config_id: exploreModelOverrideData?.model_config_id ?? "",
|
||||
},
|
||||
onSubmit: (values, { resetForm }) => {
|
||||
onSaveExploreModelOverride(
|
||||
{
|
||||
model_config_id: values.model_config_id || undefined,
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
resetForm({ values });
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const isUnavailableSavedModel =
|
||||
form.values.model_config_id !== "" &&
|
||||
!enabledModelOptions.some(
|
||||
(option) => option.id === form.values.model_config_id,
|
||||
);
|
||||
const hasMalformedOverride =
|
||||
exploreModelOverrideData?.has_malformed_override ?? false;
|
||||
const isExploreModelOverrideDisabled =
|
||||
isSavingExploreModelOverride ||
|
||||
isLoadingModelConfigs ||
|
||||
!hasLoadedExploreModelOverride;
|
||||
const canSaveExploreModelOverride =
|
||||
hasLoadedExploreModelOverride && (form.dirty || hasMalformedOverride);
|
||||
|
||||
return (
|
||||
<form className="space-y-2" onSubmit={form.handleSubmit}>
|
||||
{showHeader && (
|
||||
<>
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Explore subagent model
|
||||
</h3>
|
||||
</div>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Optional deployment-wide model override for read-only Explore
|
||||
subagents spawned with <code>spawn_agent</code> using
|
||||
<code>type=explore</code>.
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
<div className="rounded-lg border border-border bg-surface-primary px-3 py-2">
|
||||
<ModelSelector
|
||||
options={enabledModelOptions}
|
||||
value={form.values.model_config_id}
|
||||
onValueChange={(value) =>
|
||||
form.setFieldValue("model_config_id", value)
|
||||
}
|
||||
disabled={isExploreModelOverrideDisabled}
|
||||
placeholder={
|
||||
isUnavailableSavedModel ? "Unavailable model" : "Use chat default"
|
||||
}
|
||||
emptyMessage={
|
||||
isLoadingModelConfigs
|
||||
? "Loading models..."
|
||||
: "No enabled models found."
|
||||
}
|
||||
className="h-10 w-full justify-between rounded-md border border-border border-solid bg-transparent px-3 text-sm shadow-sm"
|
||||
contentClassName="min-w-[18rem]"
|
||||
/>
|
||||
</div>
|
||||
{isUnavailableSavedModel && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
The saved model is no longer enabled and will be ignored until you
|
||||
choose a new override.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
{hasMalformedOverride && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
The saved override is malformed and is being treated as unset. Click
|
||||
Save to clear it.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
{Boolean(modelConfigsError) && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to load model configs.
|
||||
</p>
|
||||
)}
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
form.setFieldValue("model_config_id", "");
|
||||
}}
|
||||
disabled={isExploreModelOverrideDisabled}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={
|
||||
isExploreModelOverrideDisabled || !canSaveExploreModelOverride
|
||||
}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveExploreModelOverrideError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save Explore model override.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,174 @@
|
||||
import { useFormik } from "formik";
|
||||
import type { FC, ReactNode } from "react";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { Alert, AlertDescription } from "#/components/Alert/Alert";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import type { ModelSelectorOption } from "./ChatElements/ModelSelector";
|
||||
import { ModelSelector } from "./ChatElements/ModelSelector";
|
||||
|
||||
export interface MutationCallbacks {
|
||||
onSuccess?: () => void;
|
||||
onError?: () => void;
|
||||
}
|
||||
|
||||
interface ModelOverrideData {
|
||||
readonly model_config_id: string;
|
||||
readonly is_malformed: boolean;
|
||||
}
|
||||
|
||||
interface UpdateModelOverrideRequest {
|
||||
readonly model_config_id: string;
|
||||
}
|
||||
|
||||
interface SubagentModelOverrideSettingsProps {
|
||||
title: string;
|
||||
description: ReactNode;
|
||||
modelOverrideData: ModelOverrideData | undefined;
|
||||
enabledModelConfigs: readonly TypesGen.ChatModelConfig[];
|
||||
modelConfigsError: unknown;
|
||||
isLoading: boolean;
|
||||
onSaveModelOverride: (
|
||||
req: UpdateModelOverrideRequest,
|
||||
options?: MutationCallbacks,
|
||||
) => void;
|
||||
isSaving: boolean;
|
||||
isSaveError: boolean;
|
||||
saveErrorMessage: string;
|
||||
showHeader?: boolean;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const toModelSelectorOption = (
|
||||
modelConfig: TypesGen.ChatModelConfig,
|
||||
): ModelSelectorOption => ({
|
||||
id: modelConfig.id,
|
||||
provider: modelConfig.provider,
|
||||
model: modelConfig.model,
|
||||
displayName: modelConfig.display_name.trim() || modelConfig.model,
|
||||
contextLimit: modelConfig.context_limit,
|
||||
});
|
||||
|
||||
export const SubagentModelOverrideSettings: FC<
|
||||
SubagentModelOverrideSettingsProps
|
||||
> = ({
|
||||
title,
|
||||
description,
|
||||
modelOverrideData,
|
||||
enabledModelConfigs,
|
||||
modelConfigsError,
|
||||
isLoading,
|
||||
onSaveModelOverride,
|
||||
isSaving,
|
||||
isSaveError,
|
||||
saveErrorMessage,
|
||||
showHeader = true,
|
||||
disabled = false,
|
||||
}) => {
|
||||
const hasLoadedModelOverride = modelOverrideData !== undefined;
|
||||
const enabledModelOptions = enabledModelConfigs.map(toModelSelectorOption);
|
||||
|
||||
const form = useFormik({
|
||||
enableReinitialize: true,
|
||||
initialValues: {
|
||||
model_config_id: modelOverrideData?.model_config_id ?? "",
|
||||
},
|
||||
onSubmit: (values, { resetForm }) => {
|
||||
onSaveModelOverride(
|
||||
{
|
||||
model_config_id: values.model_config_id,
|
||||
},
|
||||
{
|
||||
onSuccess: () => {
|
||||
resetForm({ values });
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const isUnavailableSavedModel =
|
||||
form.values.model_config_id !== "" &&
|
||||
!enabledModelOptions.some(
|
||||
(option) => option.id === form.values.model_config_id,
|
||||
);
|
||||
const isMalformedOverride = modelOverrideData?.is_malformed ?? false;
|
||||
const isModelOverrideDisabled =
|
||||
disabled || isSaving || isLoading || !hasLoadedModelOverride;
|
||||
const canSaveModelOverride =
|
||||
hasLoadedModelOverride && (form.dirty || isMalformedOverride);
|
||||
|
||||
return (
|
||||
<form aria-label={title} className="space-y-2" onSubmit={form.handleSubmit}>
|
||||
{showHeader && (
|
||||
<>
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
{title}
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
{description}
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
<ModelSelector
|
||||
options={enabledModelOptions}
|
||||
value={form.values.model_config_id}
|
||||
onValueChange={(value) => form.setFieldValue("model_config_id", value)}
|
||||
disabled={isModelOverrideDisabled}
|
||||
placeholder={
|
||||
isUnavailableSavedModel ? "Unavailable model" : "Use chat default"
|
||||
}
|
||||
emptyMessage={
|
||||
isLoading ? "Loading models..." : "No enabled models found."
|
||||
}
|
||||
className="h-10 w-full justify-between rounded-md border border-border border-solid bg-transparent px-3 text-sm shadow-sm"
|
||||
contentClassName="min-w-[18rem]"
|
||||
/>
|
||||
{isUnavailableSavedModel && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
The saved model is no longer enabled and will be ignored until you
|
||||
choose a new override.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
{isMalformedOverride && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
The saved override is malformed and is being treated as unset. Click
|
||||
Save to clear it.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
{Boolean(modelConfigsError) && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to load model configs.
|
||||
</p>
|
||||
)}
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
form.setFieldValue("model_config_id", "");
|
||||
}}
|
||||
disabled={isModelOverrideDisabled}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isModelOverrideDisabled || !canSaveModelOverride}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
{saveErrorMessage}
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
);
|
||||
};
|
||||
Reference in New Issue
Block a user