From 796872f4de2c0df55a1ffaf616fe7788305e922c Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Wed, 25 Mar 2026 15:19:17 +0000 Subject: [PATCH] feat: add deployment-wide template allowlist for chats (#23262) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Stores a deployment-wide agents template allowlist in `site_configs` (`agents_template_allowlist`) - Adds `GET/PUT /api/experimental/chats/config/template-allowlist` endpoints - Filters `list_templates`, `read_template`, and `create_workspace` chat tools by allowlist, if defined (empty=all allowed) - Add "Templates" admin settings tab in Agents UI ([what it looks like](https://624de63c6aacee003aa84340-sitjilsyrr.chromatic.com/?path=/story/pages-agentspage-agentsettingspageview--template-allowlist)) > 🤖 This PR was created with the help of Coder Agents, and has been reviewed by my human. 🧑‍💻 --- coderd/coderd.go | 2 + coderd/database/dbauthz/dbauthz.go | 18 ++ coderd/database/dbauthz/dbauthz_test.go | 8 + coderd/database/dbmetrics/querymetrics.go | 16 ++ coderd/database/dbmock/dbmock.go | 29 +++ coderd/database/querier.go | 4 + coderd/database/queries.sql.go | 24 +++ coderd/database/queries/siteconfig.sql | 10 + coderd/exp_chats.go | 148 ++++++++++++++ coderd/exp_chats_test.go | 128 ++++++++++++ coderd/x/chatd/chatd.go | 54 ++++- coderd/x/chatd/chattool/chattool.go | 15 ++ coderd/x/chatd/chattool/createworkspace.go | 6 +- coderd/x/chatd/chattool/listtemplates.go | 14 +- coderd/x/chatd/chattool/listtemplates_test.go | 188 ++++++++++++++++++ coderd/x/chatd/chattool/readtemplate.go | 9 +- codersdk/chats.go | 34 ++++ site/src/api/api.ts | 17 ++ site/src/api/queries/chats.ts | 16 ++ site/src/api/typesGenerated.ts | 10 + .../AgentSettingsPageView.stories.tsx | 149 +++++++++++++- .../AgentsPage/AgentSettingsPageView.tsx | 147 ++++++++++++++ .../components/Sidebar/AgentsSidebar.tsx | 9 + 23 files changed, 1045 insertions(+), 10 deletions(-) create mode 100644 coderd/x/chatd/chattool/listtemplates_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 5f21cb40d2..118493790c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1186,6 +1186,8 @@ func New(options *Options) *API { r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold) r.Get("/workspace-ttl", api.getChatWorkspaceTTL) r.Put("/workspace-ttl", api.putChatWorkspaceTTL) + r.Get("/template-allowlist", api.getChatTemplateAllowlist) + r.Put("/template-allowlist", api.putChatTemplateAllowlist) }) // TODO(cian): place under /api/experimental/chats/config r.Route("/providers", func(r chi.Router) { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index e33f160e1a..347b5a1a1f 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2674,6 +2674,17 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) { return q.db.GetChatSystemPrompt(ctx) } +// GetChatTemplateAllowlist requires deployment-config read permission, +// unlike the peer getters (GetChatDesktopEnabled, etc.) which only +// check actor presence. The allowlist is admin-configuration that +// should not be readable by non-admin users via the HTTP API. +func (q *querier) GetChatTemplateAllowlist(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return "", err + } + return q.db.GetChatTemplateAllowlist(ctx) +} + func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return database.ChatUsageLimitConfig{}, err @@ -6812,6 +6823,13 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro return q.db.UpsertChatSystemPrompt(ctx, value) } +func (q *querier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist) +} + func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return database.ChatUsageLimitConfig{}, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index a8d4d197d2..20640ed74d 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -656,6 +656,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes() check.Args().Asserts() })) + s.Run("GetChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetChatTemplateAllowlist(gomock.Any()).Return("", nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) + })) s.Run("GetChatWorkspaceTTL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetChatWorkspaceTTL(gomock.Any()).Return("1h", nil).AnyTimes() check.Args().Asserts() @@ -873,6 +877,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) + s.Run("UpsertChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertChatTemplateAllowlist(gomock.Any(), "").Return(nil).AnyTimes() + check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) s.Run("UpsertChatWorkspaceTTL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().UpsertChatWorkspaceTTL(gomock.Any(), "1h").Return(nil).AnyTimes() check.Args("1h").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index c41e2a4647..2890e31535 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1208,6 +1208,14 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err return r0, r1 } +func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetChatTemplateAllowlist(ctx) + m.queryLatencies.WithLabelValues("GetChatTemplateAllowlist").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatTemplateAllowlist").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) { start := time.Now() r0, r1 := m.s.GetChatUsageLimitConfig(ctx) @@ -4808,6 +4816,14 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str return r0 } +func (m queryMetricsStore) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error { + start := time.Now() + r0 := m.s.UpsertChatTemplateAllowlist(ctx, templateAllowlist) + m.queryLatencies.WithLabelValues("UpsertChatTemplateAllowlist").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatTemplateAllowlist").Inc() + return r0 +} + func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) { start := time.Now() r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index a061d6ebab..7dff5e62bb 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2223,6 +2223,21 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx) } +// GetChatTemplateAllowlist mocks base method. +func (m *MockStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatTemplateAllowlist", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatTemplateAllowlist indicates an expected call of GetChatTemplateAllowlist. +func (mr *MockStoreMockRecorder) GetChatTemplateAllowlist(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx) +} + // GetChatUsageLimitConfig mocks base method. func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) { m.ctrl.T.Helper() @@ -9013,6 +9028,20 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value) } +// UpsertChatTemplateAllowlist mocks base method. +func (m *MockStore) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChatTemplateAllowlist", ctx, templateAllowlist) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChatTemplateAllowlist indicates an expected call of UpsertChatTemplateAllowlist. +func (mr *MockStoreMockRecorder) UpsertChatTemplateAllowlist(ctx, templateAllowlist any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist) +} + // UpsertChatUsageLimitConfig mocks base method. func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 78cad95f85..1f8c00bc6a 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -254,6 +254,9 @@ type sqlcQuerier interface { GetChatProviders(ctx context.Context) ([]ChatProvider, error) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error) GetChatSystemPrompt(ctx context.Context) (string, error) + // GetChatTemplateAllowlist returns the JSON-encoded template allowlist. + // Returns an empty string when no allowlist has been configured (all templates allowed). + GetChatTemplateAllowlist(ctx context.Context) (string, error) GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error) @@ -933,6 +936,7 @@ type sqlcQuerier interface { UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error) UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error) UpsertChatSystemPrompt(ctx context.Context, value string) error + UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error) UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d0b4fa6924..47e2c01040 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -17512,6 +17512,20 @@ func (q *sqlQuerier) GetChatSystemPrompt(ctx context.Context) (string, error) { return chat_system_prompt, err } +const getChatTemplateAllowlist = `-- name: GetChatTemplateAllowlist :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist +` + +// GetChatTemplateAllowlist returns the JSON-encoded template allowlist. +// Returns an empty string when no allowlist has been configured (all templates allowed). +func (q *sqlQuerier) GetChatTemplateAllowlist(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getChatTemplateAllowlist) + var template_allowlist string + err := row.Scan(&template_allowlist) + return template_allowlist, err +} + const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one SELECT COALESCE( @@ -17743,6 +17757,16 @@ func (q *sqlQuerier) UpsertChatSystemPrompt(ctx context.Context, value string) e return err } +const upsertChatTemplateAllowlist = `-- name: UpsertChatTemplateAllowlist :exec +INSERT INTO site_configs (key, value) VALUES ('agents_template_allowlist', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_template_allowlist' +` + +func (q *sqlQuerier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error { + _, err := q.db.ExecContext(ctx, upsertChatTemplateAllowlist, templateAllowlist) + return err +} + const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec INSERT INTO site_configs (key, value) VALUES ('agents_workspace_ttl', $1::text) diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index 37b96a6594..96ebfd5f52 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -161,6 +161,12 @@ SET value = CASE END WHERE site_configs.key = 'agents_desktop_enabled'; +-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist. +-- Returns an empty string when no allowlist has been configured (all templates allowed). +-- name: GetChatTemplateAllowlist :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist; + -- name: GetChatWorkspaceTTL :one -- Returns the global TTL for chat workspaces as a Go duration string. -- Returns "0s" (disabled) when no value has been configured. @@ -170,6 +176,10 @@ SELECT '0s' )::text AS workspace_ttl; +-- name: UpsertChatTemplateAllowlist :exec +INSERT INTO site_configs (key, value) VALUES ('agents_template_allowlist', @template_allowlist) +ON CONFLICT (key) DO UPDATE SET value = @template_allowlist WHERE site_configs.key = 'agents_template_allowlist'; + -- name: UpsertChatWorkspaceTTL :exec INSERT INTO site_configs (key, value) VALUES ('agents_workspace_ttl', @workspace_ttl::text) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 5aa0fa0361..9d3540b11f 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -2777,6 +2777,154 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusNoContent) } +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. +func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) { + httpapi.ResourceNotFound(rw) + return + } + raw, err := api.Database.GetChatTemplateAllowlist(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching chat template allowlist.", + Detail: err.Error(), + }) + return + } + ids, parseErr := parseChatTemplateAllowlist(raw) + if parseErr != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Stored template allowlist is corrupt.", + Detail: parseErr.Error(), + }) + return + } + resp := codersdk.ChatTemplateAllowlist{ + TemplateIDs: ids, + } + httpapi.Write(ctx, rw, http.StatusOK, resp) +} + +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.ResourceNotFound(rw) + return + } + + var req codersdk.ChatTemplateAllowlist + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + // Validate all entries are valid UUIDs and deduplicate. + seen := make(map[string]struct{}, len(req.TemplateIDs)) + deduped := make([]string, 0, len(req.TemplateIDs)) + for _, id := range req.TemplateIDs { + parsed, err := uuid.Parse(id) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid template ID in allowlist.", + Detail: fmt.Sprintf("%q is not a valid UUID.", id), + }) + return + } + // Canonicalize to lowercase so deduplication is + // case-insensitive and stored values are consistent. + canonical := parsed.String() + if _, ok := seen[canonical]; !ok { + seen[canonical] = struct{}{} + deduped = append(deduped, canonical) + } + } + + // Convert to UUIDs for the database query. + parsedUUIDs := make([]uuid.UUID, len(deduped)) + for i, s := range deduped { + // Already validated above, safe to ignore error. + parsedUUIDs[i], _ = uuid.Parse(s) + } + + raw, err := json.Marshal(deduped) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error encoding template allowlist.", + Detail: err.Error(), + }) + return + } + + err = api.Database.InTx(func(tx database.Store) error { + // Verify all IDs refer to existing, non-deprecated templates + // in a single query. + if len(parsedUUIDs) > 0 { + found, err := tx.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{ + IDs: parsedUUIDs, + Deprecated: sql.NullBool{ + Bool: false, + Valid: true, + }, + }) + if err != nil { + return xerrors.Errorf("fetch templates: %w", err) + } + if len(found) != len(parsedUUIDs) { + foundSet := make(map[uuid.UUID]struct{}, len(found)) + for _, t := range found { + foundSet[t.ID] = struct{}{} + } + var missing []string + for _, id := range parsedUUIDs { + if _, ok := foundSet[id]; !ok { + missing = append(missing, id.String()) + } + } + return xerrors.Errorf("templates not found or deprecated: %s", strings.Join(missing, ", ")) + } + } + return tx.UpsertChatTemplateAllowlist(ctx, string(raw)) + }, nil) + if err != nil { + // If the error mentions "not found or deprecated", it's a + // validation failure, not an internal error. + if strings.Contains(err.Error(), "not found or deprecated") { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "One or more templates not found or deprecated.", + Detail: err.Error(), + }) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error updating chat template allowlist.", + Detail: err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +// parseChatTemplateAllowlist parses the raw JSON string from the +// database into a list of template ID strings. Returns an empty +// slice when the value is empty. Returns an error when the stored +// JSON is corrupt or otherwise cannot be unmarshalled. +func parseChatTemplateAllowlist(raw string) ([]string, error) { + if raw == "" { + return []string{}, nil + } + var ids []string + if err := json.Unmarshal([]byte(raw), &ids); err != nil { + return nil, xerrors.Errorf("unmarshal template allowlist: %w", err) + } + if ids == nil { + return []string{}, nil + } + return ids, nil +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index efa27cb466..f987ee978e 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -5161,6 +5161,134 @@ func TestUserChatCompactionThresholds(t *testing.T) { }) } +//nolint:tparallel // Subtests share a single coderdtest instance and run sequentially. +func TestChatTemplateAllowlist(t *testing.T) { + t.Parallel() + + // Shared setup: one coderdtest instance with two real templates. + // Subtests that need valid template IDs use these. + client, store := newChatClientWithDatabase(t) + admin := coderdtest.CreateFirstUser(t, client.Client) + tmpl1 := dbgen.Template(t, store, database.Template{ + OrganizationID: admin.OrganizationID, + CreatedBy: admin.UserID, + }) + tmpl2 := dbgen.Template(t, store, database.Template{ + OrganizationID: admin.OrganizationID, + CreatedBy: admin.UserID, + }) + deprecatedTmpl := dbgen.Template(t, store, database.Template{ + OrganizationID: admin.OrganizationID, + CreatedBy: admin.UserID, + }) + //nolint:gocritic // Owner context needed to deprecate the template in test setup. + ownerRoles, err := rbac.RoleIdentifiers{rbac.RoleOwner()}.Expand() + require.NoError(t, err) + err = store.UpdateTemplateAccessControlByID(dbauthz.As(context.Background(), rbac.Subject{ + ID: "owner", + Roles: rbac.Roles(ownerRoles), + Scope: rbac.ExpandableScope(rbac.ScopeAll), + }), database.UpdateTemplateAccessControlByIDParams{ + ID: deprecatedTmpl.ID, + Deprecated: "this template is deprecated", + }) + require.NoError(t, err, "deprecate template") + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + resp, err := client.GetChatTemplateAllowlist(ctx) + require.NoError(t, err) + require.Empty(t, resp.TemplateIDs) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("AdminCanSet", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + ids := []string{tmpl1.ID.String(), tmpl2.ID.String()} + err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: ids}) + require.NoError(t, err) + resp, err := client.GetChatTemplateAllowlist(ctx) + require.NoError(t, err) + require.ElementsMatch(t, ids, resp.TemplateIDs) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("AdminCanClear", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{}}) + require.NoError(t, err) + resp, err := client.GetChatTemplateAllowlist(ctx) + require.NoError(t, err) + require.Empty(t, resp.TemplateIDs) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("NonAdminReadFails", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + _, err := memberClient.GetChatTemplateAllowlist(ctx) + requireSDKError(t, err, http.StatusNotFound) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("NonAdminWriteFails", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + // Uses a random UUID — hits 404 before template validation. + err := memberClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}}) + requireSDKError(t, err, http.StatusNotFound) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("UnauthenticatedFails", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + anonClient := codersdk.NewExperimentalClient(codersdk.New(client.URL)) + // Uses a random UUID — hits 401 before template validation. + err := anonClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}}) + requireSDKError(t, err, http.StatusUnauthorized) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("InvalidUUIDRejected", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{"not-a-uuid"}}) + requireSDKError(t, err, http.StatusBadRequest) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("NonexistentTemplateRejected", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}}) + requireSDKError(t, err, http.StatusBadRequest) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("DeprecatedTemplateRejected", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{ + TemplateIDs: []string{deprecatedTmpl.ID.String()}, + }) + requireSDKError(t, err, http.StatusBadRequest) + }) + + //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. + t.Run("DeduplicatesIDs", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + id := tmpl1.ID.String() + err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{ + TemplateIDs: []string{id, id, id}, + }) + require.NoError(t, err) + resp, err := client.GetChatTemplateAllowlist(ctx) + require.NoError(t, err) + require.Len(t, resp.TemplateIDs, 1) + require.Equal(t, id, resp.TemplateIDs[0]) + }) +} + func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error { t.Helper() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 1a57516a43..695a1a7d9f 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -3371,6 +3371,45 @@ func (p *Server) runChat( GetWorkspaceConn: workspaceCtx.getWorkspaceConn, }), } + // getAllowedTemplateIDs returns the current deployment-wide + // template allowlist, re-reading from the database on each call + // so that admin changes take effect without restarting the chat. + // Returns nil (= all allowed) on errors to fail open. + getAllowedTemplateIDs := func() map[uuid.UUID]bool { + raw, err := p.db.GetChatTemplateAllowlist(ctx) + if err != nil { + p.logger.Error(ctx, "failed to load template allowlist, all templates will be allowed", slog.Error(err)) + return nil + } + if raw == "" { + return nil + } + var ids []string + if jsonErr := json.Unmarshal([]byte(raw), &ids); jsonErr != nil { + // Note: the API endpoint (GET /template-allowlist) returns + // HTTP 500 for corrupt JSON, giving admins visibility into + // the problem. The runtime path here deliberately fails open + // so that a corrupt allowlist doesn't block all chats. + p.logger.Error(ctx, "failed to parse template allowlist JSON, all templates will be allowed", + slog.F("raw", raw), slog.Error(jsonErr)) + return nil + } + allowlist := make(map[uuid.UUID]bool, len(ids)) + for _, s := range ids { + if id, parseErr := uuid.Parse(s); parseErr == nil { + allowlist[id] = true + } else { + p.logger.Warn(ctx, "ignoring invalid UUID in template allowlist", + slog.F("value", s), slog.Error(parseErr)) + } + } + if len(ids) > 0 && len(allowlist) == 0 { + p.logger.Error(ctx, "all UUIDs in template allowlist were invalid, all templates will be allowed", + slog.F("count", len(ids))) + return nil + } + return allowlist + } // Only root chats (not delegated subagents) get workspace // provisioning and subagent tools. Child agents must not // create workspaces or spawn further subagents — they should @@ -3379,12 +3418,14 @@ func (p *Server) runChat( // Workspace provisioning tools. tools = append(tools, chattool.ListTemplates(chattool.ListTemplatesOptions{ - DB: p.db, - OwnerID: chat.OwnerID, + DB: p.db, + OwnerID: chat.OwnerID, + AllowedTemplateIDs: getAllowedTemplateIDs, }), chattool.ReadTemplate(chattool.ReadTemplateOptions{ - DB: p.db, - OwnerID: chat.OwnerID, + DB: p.db, + OwnerID: chat.OwnerID, + AllowedTemplateIDs: getAllowedTemplateIDs, }), chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{ DB: p.db, @@ -3395,7 +3436,12 @@ func (p *Server) runChat( AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout, WorkspaceMu: &workspaceMu, Logger: p.logger, + AllowedTemplateIDs: getAllowedTemplateIDs, }), + // StartWorkspace intentionally does not enforce the + // template allowlist. The allowlist restricts creation + // of new workspaces only — existing workspaces can + // be restarted regardless of allowlist changes. chattool.StartWorkspace(chattool.StartWorkspaceOptions{ DB: p.db, OwnerID: chat.OwnerID, diff --git a/coderd/x/chatd/chattool/chattool.go b/coderd/x/chatd/chattool/chattool.go index f12d6cbf90..05f5366845 100644 --- a/coderd/x/chatd/chattool/chattool.go +++ b/coderd/x/chatd/chattool/chattool.go @@ -5,6 +5,7 @@ import ( "unicode/utf8" "charm.land/fantasy" + "github.com/google/uuid" ) // toolResponse builds a fantasy.ToolResponse from a JSON-serializable @@ -31,3 +32,17 @@ func truncateRunes(value string, maxLen int) string { } return string(runes[:maxLen]) } + +// isTemplateAllowed checks whether a template ID is permitted by the +// configured allowlist. A nil function or an empty allowlist means +// all templates are allowed. +func isTemplateAllowed(getAllowlist func() map[uuid.UUID]bool, id uuid.UUID) bool { + if getAllowlist == nil { + return true + } + allowlist := getAllowlist() + if len(allowlist) == 0 { + return true + } + return allowlist[id] +} diff --git a/coderd/x/chatd/chattool/createworkspace.go b/coderd/x/chatd/chattool/createworkspace.go index 18fdc07e52..33f27285f9 100644 --- a/coderd/x/chatd/chattool/createworkspace.go +++ b/coderd/x/chatd/chattool/createworkspace.go @@ -67,6 +67,7 @@ type CreateWorkspaceOptions struct { AgentInactiveDisconnectTimeout time.Duration WorkspaceMu *sync.Mutex Logger slog.Logger + AllowedTemplateIDs func() map[uuid.UUID]bool } type createWorkspaceArgs struct { @@ -106,6 +107,10 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool { ), nil } + if !isTemplateAllowed(options.AllowedTemplateIDs, templateID) { + return fantasy.NewTextErrorResponse("template not available for chat workspaces; use list_templates to find allowed templates"), nil + } + // Serialize workspace creation to prevent parallel // tool calls from creating duplicate workspaces. if options.WorkspaceMu != nil { @@ -121,7 +126,6 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool { if done { return toolResponse(existing), nil } - ownerID := options.OwnerID // Set up dbauthz context for DB lookups. diff --git a/coderd/x/chatd/chattool/listtemplates.go b/coderd/x/chatd/chattool/listtemplates.go index f11ef5d801..9c81412497 100644 --- a/coderd/x/chatd/chattool/listtemplates.go +++ b/coderd/x/chatd/chattool/listtemplates.go @@ -3,6 +3,8 @@ package chattool import ( "context" "database/sql" + "maps" + "slices" "sort" "strings" @@ -20,8 +22,9 @@ const listTemplatesPageSize = 10 // ListTemplatesOptions configures the list_templates tool. type ListTemplatesOptions struct { - DB database.Store - OwnerID uuid.UUID + DB database.Store + OwnerID uuid.UUID + AllowedTemplateIDs func() map[uuid.UUID]bool } type listTemplatesArgs struct { @@ -63,6 +66,13 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool { filterParams.FuzzyName = query } + var allowlist map[uuid.UUID]bool + if options.AllowedTemplateIDs != nil { + allowlist = options.AllowedTemplateIDs() + } + if len(allowlist) > 0 { + filterParams.IDs = slices.Collect(maps.Keys(allowlist)) + } templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams) if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil diff --git a/coderd/x/chatd/chattool/listtemplates_test.go b/coderd/x/chatd/chattool/listtemplates_test.go new file mode 100644 index 0000000000..251ded07d1 --- /dev/null +++ b/coderd/x/chatd/chattool/listtemplates_test.go @@ -0,0 +1,188 @@ +package chattool_test + +import ( + "context" + "encoding/json" + "testing" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +//nolint:tparallel,paralleltest // Subtests share a single DB and run sequentially. +func TestTemplateAllowlistEnforcement(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + + t1 := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + Name: "template-alpha", + }) + t2 := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + Name: "template-beta", + }) + + t.Run("ListTemplates", func(t *testing.T) { + t.Run("NoAllowlist", func(t *testing.T) { + tool := chattool.ListTemplates(chattool.ListTemplatesOptions{ + DB: db, + OwnerID: user.ID, + }) + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c1", Name: "list_templates", Input: "{}"}) + require.NoError(t, err) + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + templates := result["templates"].([]any) + require.Len(t, templates, 2) + }) + + t.Run("EmptyAllowlist", func(t *testing.T) { + tool := chattool.ListTemplates(chattool.ListTemplatesOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{} }, + }) + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c2", Name: "list_templates", Input: "{}"}) + require.NoError(t, err) + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + templates := result["templates"].([]any) + require.Len(t, templates, 2) + }) + + t.Run("OneMatch", func(t *testing.T) { + tool := chattool.ListTemplates(chattool.ListTemplatesOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} }, + }) + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c3", Name: "list_templates", Input: "{}"}) + require.NoError(t, err) + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + templates := result["templates"].([]any) + require.Len(t, templates, 1) + m := templates[0].(map[string]any) + require.Equal(t, t1.ID.String(), m["id"].(string)) + }) + + t.Run("NoMatches", func(t *testing.T) { + tool := chattool.ListTemplates(chattool.ListTemplatesOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} }, + }) + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c4", Name: "list_templates", Input: "{}"}) + require.NoError(t, err) + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + templates := result["templates"].([]any) + require.Empty(t, templates) + }) + }) + + t.Run("ReadTemplate", func(t *testing.T) { + t.Run("Allowed", func(t *testing.T) { + tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} }, + }) + input := `{"template_id":"` + t1.ID.String() + `"}` + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c5", Name: "read_template", Input: input}) + require.NoError(t, err) + require.False(t, resp.IsError) + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + tmplInfo := result["template"].(map[string]any) + require.Equal(t, t1.ID.String(), tmplInfo["id"].(string)) + }) + + t.Run("Disallowed", func(t *testing.T) { + tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} }, + }) + input := `{"template_id":"` + t2.ID.String() + `"}` + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c6", Name: "read_template", Input: input}) + require.NoError(t, err) + require.True(t, resp.IsError) + require.Contains(t, resp.Content, "not found") + }) + + t.Run("NoAllowlist", func(t *testing.T) { + tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{ + DB: db, + OwnerID: user.ID, + }) + input := `{"template_id":"` + t2.ID.String() + `"}` + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c7", Name: "read_template", Input: input}) + require.NoError(t, err) + require.False(t, resp.IsError) + }) + }) + + t.Run("CreateWorkspace", func(t *testing.T) { + t.Run("Allowed", func(t *testing.T) { + createCalled := false + tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} }, + CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) { + createCalled = true + return codersdk.Workspace{}, nil + }, + }) + input := `{"template_id":"` + t1.ID.String() + `"}` + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8a", Name: "create_workspace", Input: input}) + require.NoError(t, err) + require.True(t, createCalled, "CreateFn should be called for allowed template") + // We don't assert resp.IsError here because CreateWorkspace + // does additional work (asOwner, workspace lookup) that + // depends on full RBAC setup. The key assertion is that + // the allowlist gate passed and CreateFn was invoked. + _ = resp + }) + + t.Run("Disallowed", func(t *testing.T) { + createCalled := false + tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{ + DB: db, + OwnerID: user.ID, + AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} }, + CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) { + createCalled = true + t.Fatal("CreateFn should not be called for blocked template") + return codersdk.Workspace{}, nil + }, + }) + input := `{"template_id":"` + t1.ID.String() + `"}` + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8", Name: "create_workspace", Input: input}) + require.NoError(t, err) + require.True(t, resp.IsError) + require.Contains(t, resp.Content, "template not available for chat workspaces") + require.False(t, createCalled, "CreateFn should not be called for blocked template") + }) + }) +} diff --git a/coderd/x/chatd/chattool/readtemplate.go b/coderd/x/chatd/chattool/readtemplate.go index beae79ce46..7cc66ff569 100644 --- a/coderd/x/chatd/chattool/readtemplate.go +++ b/coderd/x/chatd/chattool/readtemplate.go @@ -14,8 +14,9 @@ import ( // ReadTemplateOptions configures the read_template tool. type ReadTemplateOptions struct { - DB database.Store - OwnerID uuid.UUID + DB database.Store + OwnerID uuid.UUID + AllowedTemplateIDs func() map[uuid.UUID]bool } type readTemplateArgs struct { @@ -48,6 +49,10 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool { ), nil } + if !isTemplateAllowed(options.AllowedTemplateIDs, templateID) { + return fantasy.NewTextErrorResponse("template not found"), nil + } + ctx, err = asOwner(ctx, options.DB, options.OwnerID) if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil diff --git a/codersdk/chats.go b/codersdk/chats.go index 1e6d712da3..5b31c3d9f8 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -425,6 +425,13 @@ func ParseChatWorkspaceTTL(s string) (time.Duration, error) { return d, nil } +// ChatTemplateAllowlist is the request and response body for the +// chat template allowlist configuration endpoint. An empty list +// means all templates are allowed. +type ChatTemplateAllowlist struct { + TemplateIDs []string `json:"template_ids"` +} + // ChatProviderConfigSource describes how a provider entry is sourced. type ChatProviderConfigSource string @@ -1444,6 +1451,33 @@ func (c *ExperimentalClient) UpdateChatWorkspaceTTL(ctx context.Context, req Upd return nil } +// GetChatTemplateAllowlist returns the deployment-wide chat template allowlist. +func (c *ExperimentalClient) GetChatTemplateAllowlist(ctx context.Context) (ChatTemplateAllowlist, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/template-allowlist", nil) + if err != nil { + return ChatTemplateAllowlist{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ChatTemplateAllowlist{}, ReadBodyAsError(res) + } + var resp ChatTemplateAllowlist + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// UpdateChatTemplateAllowlist updates the deployment-wide chat template allowlist. +func (c *ExperimentalClient) UpdateChatTemplateAllowlist(ctx context.Context, req ChatTemplateAllowlist) error { + res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/template-allowlist", req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // UpdateUserChatCustomPrompt updates the user's custom chat prompt. func (c *ExperimentalClient) UpdateUserChatCustomPrompt(ctx context.Context, req UserChatCustomPrompt) (UserChatCustomPrompt, error) { res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/user-prompt", req) diff --git a/site/src/api/api.ts b/site/src/api/api.ts index d31c599eda..2d40b4e709 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3216,12 +3216,29 @@ class ExperimentalApiMethods { return response.data; }; + getChatTemplateAllowlist = + async (): Promise => { + const response = await this.axios.get( + "/api/experimental/chats/config/template-allowlist", + ); + return response.data; + }; + updateChatWorkspaceTTL = async ( req: TypesGen.UpdateChatWorkspaceTTLRequest, ): Promise => { await this.axios.put("/api/experimental/chats/config/workspace-ttl", req); }; + updateChatTemplateAllowlist = async ( + req: TypesGen.ChatTemplateAllowlist, + ): Promise => { + await this.axios.put( + "/api/experimental/chats/config/template-allowlist", + req, + ); + }; + getUserChatCustomPrompt = async (): Promise => { const response = await this.axios.get( diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 991d75e5aa..edb66c00d6 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -439,6 +439,22 @@ export const updateChatWorkspaceTTL = (queryClient: QueryClient) => ({ }, }); +const chatTemplateAllowlistKey = ["chat-template-allowlist"] as const; + +export const chatTemplateAllowlist = () => ({ + queryKey: chatTemplateAllowlistKey, + queryFn: () => API.experimental.getChatTemplateAllowlist(), +}); + +export const updateChatTemplateAllowlist = (queryClient: QueryClient) => ({ + mutationFn: API.experimental.updateChatTemplateAllowlist, + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: chatTemplateAllowlistKey, + }); + }, +}); + const chatUserCustomPromptKey = ["chat-user-custom-prompt"] as const; export const chatUserCustomPrompt = () => ({ diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 918ac53313..30408821f2 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1880,6 +1880,16 @@ export interface ChatSystemPrompt { readonly system_prompt: string; } +// From codersdk/chats.go +/** + * ChatTemplateAllowlist is the request and response body for the + * chat template allowlist configuration endpoint. An empty list + * means all templates are allowed. + */ +export interface ChatTemplateAllowlist { + readonly template_ids: readonly string[]; +} + // From codersdk/chats.go export interface ChatTextPart { readonly type: "text"; diff --git a/site/src/pages/AgentsPage/AgentSettingsPageView.stories.tsx b/site/src/pages/AgentsPage/AgentSettingsPageView.stories.tsx index 4ba83dd6b4..8fcfe30149 100644 --- a/site/src/pages/AgentsPage/AgentSettingsPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsPageView.stories.tsx @@ -1,4 +1,4 @@ -import { MockUserOwner } from "testHelpers/entities"; +import { MockTemplate, MockUserOwner } from "testHelpers/entities"; import { withAuthProvider, withDashboardProvider } from "testHelpers/storybook"; import type { Meta, StoryObj } from "@storybook/react-vite"; import { API } from "api/api"; @@ -170,6 +170,30 @@ const meta = { workspace_ttl_ms: 0, }); spyOn(API.experimental, "updateChatWorkspaceTTL").mockResolvedValue(); + spyOn(API.experimental, "getChatTemplateAllowlist").mockResolvedValue({ + template_ids: [], + }); + spyOn(API.experimental, "updateChatTemplateAllowlist").mockResolvedValue(); + spyOn(API, "getTemplates").mockResolvedValue([ + { + ...MockTemplate, + id: "abc-123", + name: "docker-dev", + display_name: "Docker Development", + }, + { + ...MockTemplate, + id: "def-456", + name: "kubernetes-prod", + display_name: "Kubernetes Production", + }, + { + ...MockTemplate, + id: "ghi-789", + name: "aws-windows", + display_name: "AWS Windows Desktop", + }, + ]); }, } satisfies Meta; @@ -837,3 +861,126 @@ export const NoWarningForCleanPrompt: Story = { expect(canvas.queryByText(/invisible Unicode/)).toBeNull(); }, }; + +// ── Templates tab stories ────────────────────────────────────── + +const manyTemplates = [ + { id: "t-01", name: "docker-dev", display_name: "Docker Development" }, + { + id: "t-02", + name: "kubernetes-prod", + display_name: "Kubernetes Production", + }, + { id: "t-03", name: "aws-windows", display_name: "AWS Windows Desktop" }, + { id: "t-04", name: "gcp-linux", display_name: "GCP Linux Workspace" }, + { id: "t-05", name: "azure-dotnet", display_name: "Azure .NET Environment" }, + { id: "t-06", name: "ml-jupyter", display_name: "ML Jupyter Notebook" }, + { + id: "t-07", + name: "data-eng-spark", + display_name: "Data Engineering (Spark)", + }, + { + id: "t-08", + name: "frontend-vite", + display_name: "Frontend (Vite + React)", + }, +].map((t) => ({ ...MockTemplate, ...t })); + +export const TemplateAllowlist: Story = { + args: { + activeSection: "templates", + canManageChatModelConfigs: true, + canSetSystemPrompt: true, + }, + beforeEach: () => { + // Track saved allowlist state across mock calls so the + // refetch after save returns the updated value. + let savedIDs: string[] = []; + + spyOn(API, "getTemplates").mockResolvedValue(manyTemplates); + spyOn(API.experimental, "getChatTemplateAllowlist").mockImplementation( + async () => ({ template_ids: savedIDs }), + ); + spyOn(API.experimental, "updateChatTemplateAllowlist").mockImplementation( + async (req) => { + savedIDs = [...req.template_ids]; + }, + ); + }, + play: async ({ canvasElement, step }) => { + const canvas = within(canvasElement); + + await step("starts empty", async () => { + // Status text confirms no restrictions. + await canvas.findByText(/no templates selected/i); + // Save is disabled — nothing to save. + const saveBtn = await canvas.findByRole("button", { name: "Save" }); + expect(saveBtn).toBeDisabled(); + }); + + await step("select one template and save", async () => { + // Open the combobox. + const input = canvas.getByPlaceholderText("Select templates..."); + await userEvent.click(input); + // Pick the first template from the dropdown. + await userEvent.click( + await canvas.findByRole("option", { name: "Docker Development" }), + ); + // Badge pill should appear and status should update. + await waitFor(() => { + expect(canvas.getByText("1 template selected")).toBeInTheDocument(); + }); + // Save should now be enabled. + const saveBtn = canvas.getByRole("button", { name: "Save" }); + expect(saveBtn).toBeEnabled(); + await userEvent.click(saveBtn); + await waitFor(() => { + expect( + API.experimental.updateChatTemplateAllowlist, + ).toHaveBeenCalledWith({ template_ids: ["t-01"] }); + }); + }); + + await step("add the remaining seven and save", async () => { + // Open the combobox again. + const input = canvas.getByLabelText("Select allowed templates"); + await userEvent.click(input); + // Select the other seven templates one by one. + for (const name of [ + "Kubernetes Production", + "AWS Windows Desktop", + "GCP Linux Workspace", + "Azure .NET Environment", + "ML Jupyter Notebook", + "Data Engineering (Spark)", + "Frontend (Vite + React)", + ]) { + await userEvent.click(await canvas.findByRole("option", { name })); + } + // All eight should now be selected. + await waitFor(() => { + expect(canvas.getByText("8 templates selected")).toBeInTheDocument(); + }); + // Save. + const saveBtn = canvas.getByRole("button", { name: "Save" }); + await userEvent.click(saveBtn); + await waitFor(() => { + expect( + API.experimental.updateChatTemplateAllowlist, + ).toHaveBeenLastCalledWith({ + template_ids: expect.arrayContaining([ + "t-01", + "t-02", + "t-03", + "t-04", + "t-05", + "t-06", + "t-07", + "t-08", + ]), + }); + }); + }); + }, +}; diff --git a/site/src/pages/AgentsPage/AgentSettingsPageView.tsx b/site/src/pages/AgentsPage/AgentSettingsPageView.tsx index 675fb1c805..1120f8cce5 100644 --- a/site/src/pages/AgentsPage/AgentSettingsPageView.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsPageView.tsx @@ -5,13 +5,16 @@ import { chatDesktopEnabled, chatModelConfigs, chatSystemPrompt, + chatTemplateAllowlist, chatUserCustomPrompt, chatWorkspaceTTL, updateChatDesktopEnabled, updateChatSystemPrompt, + updateChatTemplateAllowlist, updateChatWorkspaceTTL, updateUserChatCustomPrompt, } from "api/queries/chats"; +import { templates } from "api/queries/templates"; import { user } from "api/queries/users"; import type * as TypesGen from "api/typesGenerated"; import dayjs from "dayjs"; @@ -35,6 +38,10 @@ import { Alert } from "#/components/Alert/Alert"; import { AvatarData } from "#/components/Avatar/AvatarData"; import { Button } from "#/components/Button/Button"; import { Link } from "#/components/Link/Link"; +import { + MultiSelectCombobox, + type Option, +} from "#/components/MultiSelectCombobox/MultiSelectCombobox"; import { PaginationAmount } from "#/components/PaginationWidget/PaginationAmount"; import { PaginationWidgetBase } from "#/components/PaginationWidget/PaginationWidgetBase"; import { SearchField } from "#/components/SearchField/SearchField"; @@ -931,7 +938,147 @@ export const AgentSettingsPageView: FC = ({ {activeSection === "insights" && canManageChatModelConfigs && ( )} + {activeSection === "templates" && canManageChatModelConfigs && ( + + )} ); }; + +const TemplateAllowlistSection: FC = () => { + const queryClient = useQueryClient(); + + // Fetch all available templates. + const templatesQuery = useQuery(templates()); + + // Fetch current allowlist. + const allowlistQuery = useQuery(chatTemplateAllowlist()); + + const { + mutate: saveAllowlist, + isPending: isSaving, + isError: isSaveError, + } = useMutation(updateChatTemplateAllowlist(queryClient)); + + const [localSelection, setLocalSelection] = useState(null); + + // Map all templates to MultiSelectCombobox options. + const allOptions: Option[] = (templatesQuery.data ?? []).map((t) => ({ + value: t.id, + label: t.display_name || t.name, + icon: t.icon, + })); + + // Build a lookup from template ID to Option for resolving server IDs. + const optionsByID = new Map(allOptions.map((o) => [o.value, o])); + + // Resolve the server-side allowlist IDs into Option objects. + const serverSelection: Option[] = (allowlistQuery.data?.template_ids ?? []) + .map((id) => optionsByID.get(id)) + .filter((o) => o !== undefined); + + const currentSelection = localSelection ?? serverSelection; + + const serverSet = new Set(serverSelection.map((o) => o.value)); + const isDirty = + localSelection !== null && + (localSelection.length !== serverSet.size || + localSelection.some((o) => !serverSet.has(o.value))); + + const handleSave = (event: FormEvent) => { + event.preventDefault(); + if (!isDirty) return; + saveAllowlist( + { template_ids: currentSelection.map((o) => o.value) }, + { onSuccess: () => setLocalSelection(null) }, + ); + }; + + const isLoading = templatesQuery.isLoading || allowlistQuery.isLoading; + + return ( +
+ } + /> + + {isLoading && ( +
+ +
+ )} + + {!isLoading && (templatesQuery.error || allowlistQuery.error) && ( +
+

+ Failed to load template data. +

+ +
+ )} + + {!isLoading && !templatesQuery.error && !allowlistQuery.error && ( +
void handleSave(event)} + > + o.value).join(",")} + inputProps={{ "aria-label": "Select allowed templates" }} + options={allOptions} + defaultOptions={currentSelection} + value={currentSelection} + onChange={setLocalSelection} + placeholder="Select templates..." + emptyIndicator={ +

+ No templates found. +

+ } + disabled={isSaving} + hidePlaceholderWhenSelected + data-testid="template-allowlist-select" + /> +

+ {currentSelection.length > 0 + ? `${currentSelection.length} template${currentSelection.length !== 1 ? "s" : ""} selected` + : "No templates selected \u2014 all templates are available"} +

+ +
+ +
+ + {isSaveError && ( +

+ Failed to save template allowlist. +

+ )} + + )} +
+ ); +}; diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx index 2a3746ba4f..2dd13a9181 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx @@ -22,6 +22,7 @@ import { GitPullRequestClosedIcon, GitPullRequestDraftIcon, KeyRoundIcon, + LayoutTemplateIcon, Loader2Icon, PanelLeftCloseIcon, PauseIcon, @@ -987,6 +988,14 @@ export const AgentsSidebar: FC = (props) => { /> {isAdmin && ( <> +