diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 264fcc9648..fd0247001b 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -1226,6 +1226,73 @@ func TestListChatModelConfigs(t *testing.T) { require.True(t, found) }) + t.Run("AdminIncludesDisabledModelConfigs", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + contextLimit := int64(4096) + enabled := false + disabledConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai", + Model: "gpt-4o-disabled", + DisplayName: "GPT-4o Disabled", + Enabled: &enabled, + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + require.False(t, disabledConfig.Enabled) + + configs, err := client.ListChatModelConfigs(ctx) + require.NoError(t, err) + + found := false + for _, config := range configs { + if config.ID == disabledConfig.ID { + found = true + require.False(t, config.Enabled) + require.Equal(t, disabledConfig.DisplayName, config.DisplayName) + } + } + require.True(t, found) + }) + + t.Run("NonAdminExcludesDisabledModelConfigs", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + enabledConfig := createChatModelConfig(t, adminClient) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + contextLimit := int64(4096) + enabled := false + _, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai", + Model: "gpt-4o-disabled", + DisplayName: "GPT-4o Disabled", + Enabled: &enabled, + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + + configs, err := memberClient.ListChatModelConfigs(ctx) + require.NoError(t, err) + require.Len(t, configs, 1) + require.Equal(t, enabledConfig.ID, configs[0].ID) + require.True(t, configs[0].Enabled) + }) + t.Run("DeserializesLegacyPricingJSON", func(t *testing.T) { t.Parallel() @@ -1469,6 +1536,102 @@ func TestUpdateChatModelConfig(t *testing.T) { requireChatModelPricing(t, configs[0].ModelConfig, pricing) }) + t.Run("DisablePreservesRecordAndHidesItFromNonAdmins", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + modelConfig := createChatModelConfig(t, adminClient) + + enabled := false + updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ + Enabled: &enabled, + }) + require.NoError(t, err) + require.Equal(t, modelConfig.ID, updated.ID) + require.False(t, updated.Enabled) + + adminConfigs, err := adminClient.ListChatModelConfigs(ctx) + require.NoError(t, err) + + foundForAdmin := false + for _, config := range adminConfigs { + if config.ID == modelConfig.ID { + foundForAdmin = true + require.False(t, config.Enabled) + } + } + require.True(t, foundForAdmin) + + memberConfigs, err := memberClient.ListChatModelConfigs(ctx) + require.NoError(t, err) + for _, config := range memberConfigs { + require.NotEqual(t, modelConfig.ID, config.ID) + } + }) + + t.Run("ReEnableRestoresVisibilityForNonAdmins", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + _, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + }) + require.NoError(t, err) + + contextLimit := int64(4096) + enabled := false + modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai", + Model: "gpt-4o-reenable", + DisplayName: "GPT-4o Re-enable", + Enabled: &enabled, + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + require.False(t, modelConfig.Enabled) + + memberConfigs, err := memberClient.ListChatModelConfigs(ctx) + require.NoError(t, err) + + foundForMember := false + for _, config := range memberConfigs { + if config.ID == modelConfig.ID { + foundForMember = true + } + } + require.False(t, foundForMember) + + enabled = true + updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ + Enabled: &enabled, + }) + require.NoError(t, err) + require.Equal(t, modelConfig.ID, updated.ID) + require.True(t, updated.Enabled) + + memberConfigs, err = memberClient.ListChatModelConfigs(ctx) + require.NoError(t, err) + + foundForMember = false + for _, config := range memberConfigs { + if config.ID == modelConfig.ID { + foundForMember = true + require.True(t, config.Enabled) + } + } + require.True(t, foundForMember) + }) + t.Run("RejectsNegativePricing", func(t *testing.T) { t.Parallel() diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx index 4aaa768a14..fe089dcfc0 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx @@ -123,6 +123,7 @@ const setupChatSpies = (state: { provider: req.provider, model: req.model, display_name: req.display_name || req.model, + enabled: req.enabled ?? true, context_limit: typeof req.context_limit === "number" && Number.isFinite(req.context_limit) @@ -152,12 +153,29 @@ const setupChatSpies = (state: { spyOn(API.experimental, "deleteChatProviderConfig").mockResolvedValue( undefined, ); - spyOn(API.experimental, "updateChatModelConfig").mockResolvedValue( - createModelConfig({ - id: "stub", - provider: "stub", - model: "stub", - }), + spyOn(API.experimental, "updateChatModelConfig").mockImplementation( + async (modelConfigId, req) => { + const idx = state.modelConfigs.findIndex((m) => m.id === modelConfigId); + if (idx < 0) { + throw new Error("Model config not found."); + } + + const current = state.modelConfigs[idx]; + const updated = createModelConfig({ + ...current, + ...req, + id: current.id, + provider: current.provider, + model: current.model, + updated_at: now, + }); + + state.modelConfigs = state.modelConfigs.map((modelConfig, i) => + i === idx ? updated : modelConfig, + ); + + return updated; + }, ); }; @@ -507,6 +525,58 @@ export const SubmitModelConfigExplicitly: Story = { }, }; +export const UpdateModelEnabledToggle: Story = { + args: { section: "models" as ChatModelAdminSection }, + beforeEach: () => { + setupChatSpies({ + providerConfigs: [ + createProviderConfig({ + id: "provider-openai", + provider: "openai", + display_name: "OpenAI", + source: "database", + has_api_key: true, + }), + ], + modelConfigs: [ + createModelConfig({ + id: "model-enabled", + provider: "openai", + model: "gpt-test-enabled", + display_name: "GPT Test Enabled", + enabled: true, + }), + ], + modelCatalog: { providers: [] }, + }); + }, + play: async ({ canvasElement }) => { + const body = within(canvasElement.ownerDocument.body); + + await userEvent.click(await body.findByText("GPT Test Enabled")); + + const enabledSwitch = await body.findByRole("switch", { name: "Enabled" }); + await expect(enabledSwitch).toBeChecked(); + await userEvent.click(enabledSwitch); + await expect(enabledSwitch).not.toBeChecked(); + + await userEvent.click(body.getByRole("button", { name: "Save" })); + + await waitFor(() => { + expect(API.experimental.updateChatModelConfig).toHaveBeenCalledTimes(1); + }); + expect(API.experimental.updateChatModelConfig).toHaveBeenCalledWith( + "model-enabled", + expect.objectContaining({ enabled: false }), + ); + + const modelRow = await body.findByRole("button", { + name: /gpt test enabled/i, + }); + await expect(within(modelRow).getByText("disabled")).toBeVisible(); + }, +}; + // ── Per-provider model form stories ──────────────────────────── // Each story opens the "Add model" form for a specific provider // so you can visually verify the schema-driven fields render. diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelForm.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelForm.tsx index b97bc89bce..006ad72510 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelForm.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelForm.tsx @@ -20,6 +20,12 @@ import { SelectValue, } from "#/components/Select/Select"; import { Spinner } from "#/components/Spinner/Spinner"; +import { Switch } from "#/components/Switch/Switch"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "#/components/Tooltip/Tooltip"; import type { ProviderState } from "./ChatModelAdminPanel"; import { GeneralModelConfigFields, @@ -40,6 +46,7 @@ import { ProviderIcon } from "./ProviderIcon"; const validationSchema = Yup.object({ model: Yup.string().trim().required("Model ID is required."), displayName: Yup.string(), + enabled: Yup.boolean(), contextLimit: Yup.string() .required("Context limit is required.") .test( @@ -93,6 +100,7 @@ export const ModelForm: FC = ({ onDeleteModel, }) => { const isEditing = Boolean(editingModel); + const isDefaultModel = isEditing && editingModel?.is_default === true; const [showPricing, setShowPricing] = useState(false); const [showAdvanced, setShowAdvanced] = useState(false); const [confirmingDelete, setConfirmingDelete] = useState(false); @@ -135,6 +143,9 @@ export const ModelForm: FC = ({ ...(trimmedDisplayName !== (editingModel.display_name ?? "") && { display_name: trimmedDisplayName, }), + ...(values.enabled !== editingModel.enabled && { + enabled: values.enabled, + }), ...(parsedContextLimit !== null && parsedContextLimit !== editingModel.context_limit && { context_limit: parsedContextLimit, @@ -158,6 +169,7 @@ export const ModelForm: FC = ({ const req: TypesGen.CreateChatModelConfigRequest = { provider: selectedProviderState.provider, model: trimmedModel, + enabled: true, ...(parsedContextLimit !== null && { context_limit: parsedContextLimit, }), @@ -192,6 +204,7 @@ export const ModelForm: FC = ({ const hasFieldErrors = Object.keys(modelConfigFormBuildResult.fieldErrors).length > 0; + const defaultModelDisableGuard = isDefaultModel && form.values.enabled; // ── Provider select (shared across all form states) ─────── @@ -314,6 +327,29 @@ export const ModelForm: FC = ({ } /> + {editingModel && ( + + + + { + form.setFieldValue("enabled", v); + }} + aria-label="Enabled" + disabled={isSaving || defaultModelDisableGuard} + /> + + + + {defaultModelDisableGuard + ? "Default model cannot be disabled. Remove default status first." + : form.values.enabled + ? "Disable this model. It will be hidden from users." + : "Enable this model. It will be visible to users."} + + + )}
diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx index de33ba174f..629b0d5980 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx @@ -218,7 +218,7 @@ export const ModelsSection: FC = ({ ); const handleSetDefault = (modelConfig: TypesGen.ChatModelConfig) => { - if (modelConfig.is_default) return; + if (modelConfig.is_default || !modelConfig.enabled) return; void onUpdateModel(modelConfig.id, { is_default: true }); }; @@ -310,7 +310,11 @@ export const ModelsSection: FC = ({ e.stopPropagation(); handleSetDefault(modelConfig); }} - aria-disabled={isUpdating || modelConfig.is_default} + aria-disabled={ + isUpdating || + modelConfig.is_default || + !modelConfig.enabled + } aria-label={ modelConfig.is_default ? "Default model" @@ -320,7 +324,9 @@ export const ModelsSection: FC = ({ "flex shrink-0 items-center justify-center bg-transparent border-0 p-0 transition-colors", modelConfig.is_default ? "text-content-primary" - : "cursor-pointer text-content-secondary/30 hover:text-content-secondary", + : !modelConfig.enabled + ? "cursor-not-allowed text-content-secondary/30" + : "cursor-pointer text-content-secondary/30 hover:text-content-secondary", )} > = ({ - {modelConfig.is_default - ? "Pinned as default for new chats" - : "Pin as default for new chats"} + {!modelConfig.enabled + ? "Cannot set a disabled model as default" + : modelConfig.is_default + ? "Pinned as default for new chats" + : "Pin as default for new chats"} diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.test.ts b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.test.ts index 8305cadcad..7f7b77aff1 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.test.ts +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it } from "vitest"; import type * as TypesGen from "#/api/typesGenerated"; import { + buildInitialModelFormValues, buildModelConfigFromForm, emptyModelConfigFormState, extractModelConfigFormState, @@ -90,6 +91,35 @@ const baseChatModelConfig: TypesGen.ChatModelConfig = { updated_at: "2025-01-01T00:00:00Z", }; +// ── buildInitialModelFormValues ──────────────────────────────── + +describe("buildInitialModelFormValues", () => { + it("returns create mode defaults including enabled=true", () => { + expect(buildInitialModelFormValues()).toEqual({ + model: "", + displayName: "", + enabled: true, + contextLimit: "", + compressionThreshold: "", + isDefault: false, + config: emptyModelConfigFormState, + }); + }); + + it("preserves enabled=true when editing an enabled model", () => { + expect(buildInitialModelFormValues(baseChatModelConfig).enabled).toBe(true); + }); + + it("preserves enabled=false when editing a disabled model", () => { + expect( + buildInitialModelFormValues({ + ...baseChatModelConfig, + enabled: false, + }).enabled, + ).toBe(false); + }); +}); + // ── parsePositiveInteger ─────────────────────────────────────── describe("parsePositiveInteger", () => { diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.ts b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.ts index c0cc20857d..d832473387 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.ts +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/modelConfigFormLogic.ts @@ -25,6 +25,7 @@ export type ModelConfigFormBuildResult = { export type ModelFormValues = { model: string; displayName: string; + enabled: boolean; contextLimit: string; compressionThreshold: string; isDefault: boolean; @@ -224,6 +225,7 @@ export const buildInitialModelFormValues = ( ): ModelFormValues => ({ model: editingModel?.model ?? "", displayName: editingModel?.display_name ?? "", + enabled: editingModel?.enabled ?? true, contextLimit: editingModel ? String(editingModel.context_limit) : "", compressionThreshold: editingModel ? String(editingModel.compression_threshold)