feat: add deployment-wide template allowlist for chats (#23262)

- Stores a deployment-wide agents template allowlist in `site_configs`
(`agents_template_allowlist`)
- Adds `GET/PUT /api/experimental/chats/config/template-allowlist`
endpoints
- Filters `list_templates`, `read_template`, and `create_workspace` chat
tools by allowlist, if defined (empty=all allowed)
- Add "Templates" admin settings tab in Agents UI ([what it looks
like](https://624de63c6aacee003aa84340-sitjilsyrr.chromatic.com/?path=/story/pages-agentspage-agentsettingspageview--template-allowlist))

> 🤖 This PR was created with the help of Coder Agents, and has been
reviewed by my human. 🧑‍💻
This commit is contained in:
Cian Johnston
2026-03-25 15:19:17 +00:00
committed by GitHub
parent c0ab22dc88
commit 796872f4de
23 changed files with 1045 additions and 10 deletions
+2
View File
@@ -1186,6 +1186,8 @@ func New(options *Options) *API {
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
})
// TODO(cian): place under /api/experimental/chats/config
r.Route("/providers", func(r chi.Router) {
+18
View File
@@ -2674,6 +2674,17 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return q.db.GetChatSystemPrompt(ctx)
}
// GetChatTemplateAllowlist requires deployment-config read permission,
// unlike the peer getters (GetChatDesktopEnabled, etc.) which only
// check actor presence. The allowlist is admin-configuration that
// should not be readable by non-admin users via the HTTP API.
func (q *querier) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return "", err
}
return q.db.GetChatTemplateAllowlist(ctx)
}
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
@@ -6812,6 +6823,13 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
return q.db.UpsertChatSystemPrompt(ctx, value)
}
func (q *querier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
}
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
+8
View File
@@ -656,6 +656,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatTemplateAllowlist(gomock.Any()).Return("", nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetChatWorkspaceTTL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatWorkspaceTTL(gomock.Any()).Return("1h", nil).AnyTimes()
check.Args().Asserts()
@@ -873,6 +877,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatTemplateAllowlist(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatWorkspaceTTL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatWorkspaceTTL(gomock.Any(), "1h").Return(nil).AnyTimes()
check.Args("1h").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
+16
View File
@@ -1208,6 +1208,14 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
return r0, r1
}
func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetChatTemplateAllowlist(ctx)
m.queryLatencies.WithLabelValues("GetChatTemplateAllowlist").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatTemplateAllowlist").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
@@ -4808,6 +4816,14 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
return r0
}
func (m queryMetricsStore) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
start := time.Now()
r0 := m.s.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
m.queryLatencies.WithLabelValues("UpsertChatTemplateAllowlist").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatTemplateAllowlist").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
+29
View File
@@ -2223,6 +2223,21 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
}
// GetChatTemplateAllowlist mocks base method.
func (m *MockStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatTemplateAllowlist", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatTemplateAllowlist indicates an expected call of GetChatTemplateAllowlist.
func (mr *MockStoreMockRecorder) GetChatTemplateAllowlist(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx)
}
// GetChatUsageLimitConfig mocks base method.
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
@@ -9013,6 +9028,20 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
}
// UpsertChatTemplateAllowlist mocks base method.
func (m *MockStore) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatTemplateAllowlist", ctx, templateAllowlist)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatTemplateAllowlist indicates an expected call of UpsertChatTemplateAllowlist.
func (mr *MockStoreMockRecorder) UpsertChatTemplateAllowlist(ctx, templateAllowlist any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist)
}
// UpsertChatUsageLimitConfig mocks base method.
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
+4
View File
@@ -254,6 +254,9 @@ type sqlcQuerier interface {
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
GetChatSystemPrompt(ctx context.Context) (string, error)
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
// Returns an empty string when no allowlist has been configured (all templates allowed).
GetChatTemplateAllowlist(ctx context.Context) (string, error)
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
@@ -933,6 +936,7 @@ type sqlcQuerier interface {
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertChatSystemPrompt(ctx context.Context, value string) error
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
+24
View File
@@ -17512,6 +17512,20 @@ func (q *sqlQuerier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return chat_system_prompt, err
}
const getChatTemplateAllowlist = `-- name: GetChatTemplateAllowlist :one
SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist
`
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
// Returns an empty string when no allowlist has been configured (all templates allowed).
func (q *sqlQuerier) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
row := q.db.QueryRowContext(ctx, getChatTemplateAllowlist)
var template_allowlist string
err := row.Scan(&template_allowlist)
return template_allowlist, err
}
const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one
SELECT
COALESCE(
@@ -17743,6 +17757,16 @@ func (q *sqlQuerier) UpsertChatSystemPrompt(ctx context.Context, value string) e
return err
}
const upsertChatTemplateAllowlist = `-- name: UpsertChatTemplateAllowlist :exec
INSERT INTO site_configs (key, value) VALUES ('agents_template_allowlist', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_template_allowlist'
`
func (q *sqlQuerier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
_, err := q.db.ExecContext(ctx, upsertChatTemplateAllowlist, templateAllowlist)
return err
}
const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec
INSERT INTO site_configs (key, value)
VALUES ('agents_workspace_ttl', $1::text)
+10
View File
@@ -161,6 +161,12 @@ SET value = CASE
END
WHERE site_configs.key = 'agents_desktop_enabled';
-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
-- Returns an empty string when no allowlist has been configured (all templates allowed).
-- name: GetChatTemplateAllowlist :one
SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist;
-- name: GetChatWorkspaceTTL :one
-- Returns the global TTL for chat workspaces as a Go duration string.
-- Returns "0s" (disabled) when no value has been configured.
@@ -170,6 +176,10 @@ SELECT
'0s'
)::text AS workspace_ttl;
-- name: UpsertChatTemplateAllowlist :exec
INSERT INTO site_configs (key, value) VALUES ('agents_template_allowlist', @template_allowlist)
ON CONFLICT (key) DO UPDATE SET value = @template_allowlist WHERE site_configs.key = 'agents_template_allowlist';
-- name: UpsertChatWorkspaceTTL :exec
INSERT INTO site_configs (key, value)
VALUES ('agents_workspace_ttl', @workspace_ttl::text)
+148
View File
@@ -2777,6 +2777,154 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
httpapi.ResourceNotFound(rw)
return
}
raw, err := api.Database.GetChatTemplateAllowlist(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching chat template allowlist.",
Detail: err.Error(),
})
return
}
ids, parseErr := parseChatTemplateAllowlist(raw)
if parseErr != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Stored template allowlist is corrupt.",
Detail: parseErr.Error(),
})
return
}
resp := codersdk.ChatTemplateAllowlist{
TemplateIDs: ids,
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.ResourceNotFound(rw)
return
}
var req codersdk.ChatTemplateAllowlist
if !httpapi.Read(ctx, rw, r, &req) {
return
}
// Validate all entries are valid UUIDs and deduplicate.
seen := make(map[string]struct{}, len(req.TemplateIDs))
deduped := make([]string, 0, len(req.TemplateIDs))
for _, id := range req.TemplateIDs {
parsed, err := uuid.Parse(id)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid template ID in allowlist.",
Detail: fmt.Sprintf("%q is not a valid UUID.", id),
})
return
}
// Canonicalize to lowercase so deduplication is
// case-insensitive and stored values are consistent.
canonical := parsed.String()
if _, ok := seen[canonical]; !ok {
seen[canonical] = struct{}{}
deduped = append(deduped, canonical)
}
}
// Convert to UUIDs for the database query.
parsedUUIDs := make([]uuid.UUID, len(deduped))
for i, s := range deduped {
// Already validated above, safe to ignore error.
parsedUUIDs[i], _ = uuid.Parse(s)
}
raw, err := json.Marshal(deduped)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error encoding template allowlist.",
Detail: err.Error(),
})
return
}
err = api.Database.InTx(func(tx database.Store) error {
// Verify all IDs refer to existing, non-deprecated templates
// in a single query.
if len(parsedUUIDs) > 0 {
found, err := tx.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
IDs: parsedUUIDs,
Deprecated: sql.NullBool{
Bool: false,
Valid: true,
},
})
if err != nil {
return xerrors.Errorf("fetch templates: %w", err)
}
if len(found) != len(parsedUUIDs) {
foundSet := make(map[uuid.UUID]struct{}, len(found))
for _, t := range found {
foundSet[t.ID] = struct{}{}
}
var missing []string
for _, id := range parsedUUIDs {
if _, ok := foundSet[id]; !ok {
missing = append(missing, id.String())
}
}
return xerrors.Errorf("templates not found or deprecated: %s", strings.Join(missing, ", "))
}
}
return tx.UpsertChatTemplateAllowlist(ctx, string(raw))
}, nil)
if err != nil {
// If the error mentions "not found or deprecated", it's a
// validation failure, not an internal error.
if strings.Contains(err.Error(), "not found or deprecated") {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "One or more templates not found or deprecated.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error updating chat template allowlist.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// parseChatTemplateAllowlist parses the raw JSON string from the
// database into a list of template ID strings. Returns an empty
// slice when the value is empty. Returns an error when the stored
// JSON is corrupt or otherwise cannot be unmarshalled.
func parseChatTemplateAllowlist(raw string) ([]string, error) {
if raw == "" {
return []string{}, nil
}
var ids []string
if err := json.Unmarshal([]byte(raw), &ids); err != nil {
return nil, xerrors.Errorf("unmarshal template allowlist: %w", err)
}
if ids == nil {
return []string{}, nil
}
return ids, nil
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
+128
View File
@@ -5161,6 +5161,134 @@ func TestUserChatCompactionThresholds(t *testing.T) {
})
}
//nolint:tparallel // Subtests share a single coderdtest instance and run sequentially.
func TestChatTemplateAllowlist(t *testing.T) {
t.Parallel()
// Shared setup: one coderdtest instance with two real templates.
// Subtests that need valid template IDs use these.
client, store := newChatClientWithDatabase(t)
admin := coderdtest.CreateFirstUser(t, client.Client)
tmpl1 := dbgen.Template(t, store, database.Template{
OrganizationID: admin.OrganizationID,
CreatedBy: admin.UserID,
})
tmpl2 := dbgen.Template(t, store, database.Template{
OrganizationID: admin.OrganizationID,
CreatedBy: admin.UserID,
})
deprecatedTmpl := dbgen.Template(t, store, database.Template{
OrganizationID: admin.OrganizationID,
CreatedBy: admin.UserID,
})
//nolint:gocritic // Owner context needed to deprecate the template in test setup.
ownerRoles, err := rbac.RoleIdentifiers{rbac.RoleOwner()}.Expand()
require.NoError(t, err)
err = store.UpdateTemplateAccessControlByID(dbauthz.As(context.Background(), rbac.Subject{
ID: "owner",
Roles: rbac.Roles(ownerRoles),
Scope: rbac.ExpandableScope(rbac.ScopeAll),
}), database.UpdateTemplateAccessControlByIDParams{
ID: deprecatedTmpl.ID,
Deprecated: "this template is deprecated",
})
require.NoError(t, err, "deprecate template")
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
resp, err := client.GetChatTemplateAllowlist(ctx)
require.NoError(t, err)
require.Empty(t, resp.TemplateIDs)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("AdminCanSet", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
ids := []string{tmpl1.ID.String(), tmpl2.ID.String()}
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: ids})
require.NoError(t, err)
resp, err := client.GetChatTemplateAllowlist(ctx)
require.NoError(t, err)
require.ElementsMatch(t, ids, resp.TemplateIDs)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("AdminCanClear", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{}})
require.NoError(t, err)
resp, err := client.GetChatTemplateAllowlist(ctx)
require.NoError(t, err)
require.Empty(t, resp.TemplateIDs)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("NonAdminReadFails", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
_, err := memberClient.GetChatTemplateAllowlist(ctx)
requireSDKError(t, err, http.StatusNotFound)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("NonAdminWriteFails", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
// Uses a random UUID — hits 404 before template validation.
err := memberClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}})
requireSDKError(t, err, http.StatusNotFound)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("UnauthenticatedFails", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
anonClient := codersdk.NewExperimentalClient(codersdk.New(client.URL))
// Uses a random UUID — hits 401 before template validation.
err := anonClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}})
requireSDKError(t, err, http.StatusUnauthorized)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("InvalidUUIDRejected", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{"not-a-uuid"}})
requireSDKError(t, err, http.StatusBadRequest)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("NonexistentTemplateRejected", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}})
requireSDKError(t, err, http.StatusBadRequest)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("DeprecatedTemplateRejected", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{
TemplateIDs: []string{deprecatedTmpl.ID.String()},
})
requireSDKError(t, err, http.StatusBadRequest)
})
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
t.Run("DeduplicatesIDs", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
id := tmpl1.ID.String()
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{
TemplateIDs: []string{id, id, id},
})
require.NoError(t, err)
resp, err := client.GetChatTemplateAllowlist(ctx)
require.NoError(t, err)
require.Len(t, resp.TemplateIDs, 1)
require.Equal(t, id, resp.TemplateIDs[0])
})
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
+50 -4
View File
@@ -3371,6 +3371,45 @@ func (p *Server) runChat(
GetWorkspaceConn: workspaceCtx.getWorkspaceConn,
}),
}
// getAllowedTemplateIDs returns the current deployment-wide
// template allowlist, re-reading from the database on each call
// so that admin changes take effect without restarting the chat.
// Returns nil (= all allowed) on errors to fail open.
getAllowedTemplateIDs := func() map[uuid.UUID]bool {
raw, err := p.db.GetChatTemplateAllowlist(ctx)
if err != nil {
p.logger.Error(ctx, "failed to load template allowlist, all templates will be allowed", slog.Error(err))
return nil
}
if raw == "" {
return nil
}
var ids []string
if jsonErr := json.Unmarshal([]byte(raw), &ids); jsonErr != nil {
// Note: the API endpoint (GET /template-allowlist) returns
// HTTP 500 for corrupt JSON, giving admins visibility into
// the problem. The runtime path here deliberately fails open
// so that a corrupt allowlist doesn't block all chats.
p.logger.Error(ctx, "failed to parse template allowlist JSON, all templates will be allowed",
slog.F("raw", raw), slog.Error(jsonErr))
return nil
}
allowlist := make(map[uuid.UUID]bool, len(ids))
for _, s := range ids {
if id, parseErr := uuid.Parse(s); parseErr == nil {
allowlist[id] = true
} else {
p.logger.Warn(ctx, "ignoring invalid UUID in template allowlist",
slog.F("value", s), slog.Error(parseErr))
}
}
if len(ids) > 0 && len(allowlist) == 0 {
p.logger.Error(ctx, "all UUIDs in template allowlist were invalid, all templates will be allowed",
slog.F("count", len(ids)))
return nil
}
return allowlist
}
// Only root chats (not delegated subagents) get workspace
// provisioning and subagent tools. Child agents must not
// create workspaces or spawn further subagents — they should
@@ -3379,12 +3418,14 @@ func (p *Server) runChat(
// Workspace provisioning tools.
tools = append(tools,
chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: p.db,
OwnerID: chat.OwnerID,
DB: p.db,
OwnerID: chat.OwnerID,
AllowedTemplateIDs: getAllowedTemplateIDs,
}),
chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: p.db,
OwnerID: chat.OwnerID,
DB: p.db,
OwnerID: chat.OwnerID,
AllowedTemplateIDs: getAllowedTemplateIDs,
}),
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: p.db,
@@ -3395,7 +3436,12 @@ func (p *Server) runChat(
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
WorkspaceMu: &workspaceMu,
Logger: p.logger,
AllowedTemplateIDs: getAllowedTemplateIDs,
}),
// StartWorkspace intentionally does not enforce the
// template allowlist. The allowlist restricts creation
// of new workspaces only — existing workspaces can
// be restarted regardless of allowlist changes.
chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: p.db,
OwnerID: chat.OwnerID,
+15
View File
@@ -5,6 +5,7 @@ import (
"unicode/utf8"
"charm.land/fantasy"
"github.com/google/uuid"
)
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
@@ -31,3 +32,17 @@ func truncateRunes(value string, maxLen int) string {
}
return string(runes[:maxLen])
}
// isTemplateAllowed checks whether a template ID is permitted by the
// configured allowlist. A nil function or an empty allowlist means
// all templates are allowed.
func isTemplateAllowed(getAllowlist func() map[uuid.UUID]bool, id uuid.UUID) bool {
if getAllowlist == nil {
return true
}
allowlist := getAllowlist()
if len(allowlist) == 0 {
return true
}
return allowlist[id]
}
+5 -1
View File
@@ -67,6 +67,7 @@ type CreateWorkspaceOptions struct {
AgentInactiveDisconnectTimeout time.Duration
WorkspaceMu *sync.Mutex
Logger slog.Logger
AllowedTemplateIDs func() map[uuid.UUID]bool
}
type createWorkspaceArgs struct {
@@ -106,6 +107,10 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
), nil
}
if !isTemplateAllowed(options.AllowedTemplateIDs, templateID) {
return fantasy.NewTextErrorResponse("template not available for chat workspaces; use list_templates to find allowed templates"), nil
}
// Serialize workspace creation to prevent parallel
// tool calls from creating duplicate workspaces.
if options.WorkspaceMu != nil {
@@ -121,7 +126,6 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
if done {
return toolResponse(existing), nil
}
ownerID := options.OwnerID
// Set up dbauthz context for DB lookups.
+12 -2
View File
@@ -3,6 +3,8 @@ package chattool
import (
"context"
"database/sql"
"maps"
"slices"
"sort"
"strings"
@@ -20,8 +22,9 @@ const listTemplatesPageSize = 10
// ListTemplatesOptions configures the list_templates tool.
type ListTemplatesOptions struct {
DB database.Store
OwnerID uuid.UUID
DB database.Store
OwnerID uuid.UUID
AllowedTemplateIDs func() map[uuid.UUID]bool
}
type listTemplatesArgs struct {
@@ -63,6 +66,13 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
filterParams.FuzzyName = query
}
var allowlist map[uuid.UUID]bool
if options.AllowedTemplateIDs != nil {
allowlist = options.AllowedTemplateIDs()
}
if len(allowlist) > 0 {
filterParams.IDs = slices.Collect(maps.Keys(allowlist))
}
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
@@ -0,0 +1,188 @@
package chattool_test
import (
"context"
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
//nolint:tparallel,paralleltest // Subtests share a single DB and run sequentially.
func TestTemplateAllowlistEnforcement(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
t1 := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
Name: "template-alpha",
})
t2 := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
Name: "template-beta",
})
t.Run("ListTemplates", func(t *testing.T) {
t.Run("NoAllowlist", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
OwnerID: user.ID,
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c1", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
templates := result["templates"].([]any)
require.Len(t, templates, 2)
})
t.Run("EmptyAllowlist", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{} },
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c2", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
templates := result["templates"].([]any)
require.Len(t, templates, 2)
})
t.Run("OneMatch", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c3", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
templates := result["templates"].([]any)
require.Len(t, templates, 1)
m := templates[0].(map[string]any)
require.Equal(t, t1.ID.String(), m["id"].(string))
})
t.Run("NoMatches", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c4", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
templates := result["templates"].([]any)
require.Empty(t, templates)
})
})
t.Run("ReadTemplate", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
})
input := `{"template_id":"` + t1.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c5", Name: "read_template", Input: input})
require.NoError(t, err)
require.False(t, resp.IsError)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
tmplInfo := result["template"].(map[string]any)
require.Equal(t, t1.ID.String(), tmplInfo["id"].(string))
})
t.Run("Disallowed", func(t *testing.T) {
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
})
input := `{"template_id":"` + t2.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c6", Name: "read_template", Input: input})
require.NoError(t, err)
require.True(t, resp.IsError)
require.Contains(t, resp.Content, "not found")
})
t.Run("NoAllowlist", func(t *testing.T) {
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: db,
OwnerID: user.ID,
})
input := `{"template_id":"` + t2.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c7", Name: "read_template", Input: input})
require.NoError(t, err)
require.False(t, resp.IsError)
})
})
t.Run("CreateWorkspace", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
createCalled := false
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
createCalled = true
return codersdk.Workspace{}, nil
},
})
input := `{"template_id":"` + t1.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8a", Name: "create_workspace", Input: input})
require.NoError(t, err)
require.True(t, createCalled, "CreateFn should be called for allowed template")
// We don't assert resp.IsError here because CreateWorkspace
// does additional work (asOwner, workspace lookup) that
// depends on full RBAC setup. The key assertion is that
// the allowlist gate passed and CreateFn was invoked.
_ = resp
})
t.Run("Disallowed", func(t *testing.T) {
createCalled := false
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: db,
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
createCalled = true
t.Fatal("CreateFn should not be called for blocked template")
return codersdk.Workspace{}, nil
},
})
input := `{"template_id":"` + t1.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8", Name: "create_workspace", Input: input})
require.NoError(t, err)
require.True(t, resp.IsError)
require.Contains(t, resp.Content, "template not available for chat workspaces")
require.False(t, createCalled, "CreateFn should not be called for blocked template")
})
})
}
+7 -2
View File
@@ -14,8 +14,9 @@ import (
// ReadTemplateOptions configures the read_template tool.
type ReadTemplateOptions struct {
DB database.Store
OwnerID uuid.UUID
DB database.Store
OwnerID uuid.UUID
AllowedTemplateIDs func() map[uuid.UUID]bool
}
type readTemplateArgs struct {
@@ -48,6 +49,10 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
), nil
}
if !isTemplateAllowed(options.AllowedTemplateIDs, templateID) {
return fantasy.NewTextErrorResponse("template not found"), nil
}
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
+34
View File
@@ -425,6 +425,13 @@ func ParseChatWorkspaceTTL(s string) (time.Duration, error) {
return d, nil
}
// ChatTemplateAllowlist is the request and response body for the
// chat template allowlist configuration endpoint. An empty list
// means all templates are allowed.
type ChatTemplateAllowlist struct {
TemplateIDs []string `json:"template_ids"`
}
// ChatProviderConfigSource describes how a provider entry is sourced.
type ChatProviderConfigSource string
@@ -1444,6 +1451,33 @@ func (c *ExperimentalClient) UpdateChatWorkspaceTTL(ctx context.Context, req Upd
return nil
}
// GetChatTemplateAllowlist returns the deployment-wide chat template allowlist.
func (c *ExperimentalClient) GetChatTemplateAllowlist(ctx context.Context) (ChatTemplateAllowlist, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/template-allowlist", nil)
if err != nil {
return ChatTemplateAllowlist{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatTemplateAllowlist{}, ReadBodyAsError(res)
}
var resp ChatTemplateAllowlist
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// UpdateChatTemplateAllowlist updates the deployment-wide chat template allowlist.
func (c *ExperimentalClient) UpdateChatTemplateAllowlist(ctx context.Context, req ChatTemplateAllowlist) error {
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/template-allowlist", req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// UpdateUserChatCustomPrompt updates the user's custom chat prompt.
func (c *ExperimentalClient) UpdateUserChatCustomPrompt(ctx context.Context, req UserChatCustomPrompt) (UserChatCustomPrompt, error) {
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/user-prompt", req)
+17
View File
@@ -3216,12 +3216,29 @@ class ExperimentalApiMethods {
return response.data;
};
getChatTemplateAllowlist =
async (): Promise<TypesGen.ChatTemplateAllowlist> => {
const response = await this.axios.get<TypesGen.ChatTemplateAllowlist>(
"/api/experimental/chats/config/template-allowlist",
);
return response.data;
};
updateChatWorkspaceTTL = async (
req: TypesGen.UpdateChatWorkspaceTTLRequest,
): Promise<void> => {
await this.axios.put("/api/experimental/chats/config/workspace-ttl", req);
};
updateChatTemplateAllowlist = async (
req: TypesGen.ChatTemplateAllowlist,
): Promise<void> => {
await this.axios.put(
"/api/experimental/chats/config/template-allowlist",
req,
);
};
getUserChatCustomPrompt =
async (): Promise<TypesGen.UserChatCustomPrompt> => {
const response = await this.axios.get<TypesGen.UserChatCustomPrompt>(
+16
View File
@@ -439,6 +439,22 @@ export const updateChatWorkspaceTTL = (queryClient: QueryClient) => ({
},
});
const chatTemplateAllowlistKey = ["chat-template-allowlist"] as const;
export const chatTemplateAllowlist = () => ({
queryKey: chatTemplateAllowlistKey,
queryFn: () => API.experimental.getChatTemplateAllowlist(),
});
export const updateChatTemplateAllowlist = (queryClient: QueryClient) => ({
mutationFn: API.experimental.updateChatTemplateAllowlist,
onSuccess: async () => {
await queryClient.invalidateQueries({
queryKey: chatTemplateAllowlistKey,
});
},
});
const chatUserCustomPromptKey = ["chat-user-custom-prompt"] as const;
export const chatUserCustomPrompt = () => ({
+10
View File
@@ -1880,6 +1880,16 @@ export interface ChatSystemPrompt {
readonly system_prompt: string;
}
// From codersdk/chats.go
/**
* ChatTemplateAllowlist is the request and response body for the
* chat template allowlist configuration endpoint. An empty list
* means all templates are allowed.
*/
export interface ChatTemplateAllowlist {
readonly template_ids: readonly string[];
}
// From codersdk/chats.go
export interface ChatTextPart {
readonly type: "text";
@@ -1,4 +1,4 @@
import { MockUserOwner } from "testHelpers/entities";
import { MockTemplate, MockUserOwner } from "testHelpers/entities";
import { withAuthProvider, withDashboardProvider } from "testHelpers/storybook";
import type { Meta, StoryObj } from "@storybook/react-vite";
import { API } from "api/api";
@@ -170,6 +170,30 @@ const meta = {
workspace_ttl_ms: 0,
});
spyOn(API.experimental, "updateChatWorkspaceTTL").mockResolvedValue();
spyOn(API.experimental, "getChatTemplateAllowlist").mockResolvedValue({
template_ids: [],
});
spyOn(API.experimental, "updateChatTemplateAllowlist").mockResolvedValue();
spyOn(API, "getTemplates").mockResolvedValue([
{
...MockTemplate,
id: "abc-123",
name: "docker-dev",
display_name: "Docker Development",
},
{
...MockTemplate,
id: "def-456",
name: "kubernetes-prod",
display_name: "Kubernetes Production",
},
{
...MockTemplate,
id: "ghi-789",
name: "aws-windows",
display_name: "AWS Windows Desktop",
},
]);
},
} satisfies Meta<typeof AgentSettingsPageView>;
@@ -837,3 +861,126 @@ export const NoWarningForCleanPrompt: Story = {
expect(canvas.queryByText(/invisible Unicode/)).toBeNull();
},
};
// ── Templates tab stories ──────────────────────────────────────
const manyTemplates = [
{ id: "t-01", name: "docker-dev", display_name: "Docker Development" },
{
id: "t-02",
name: "kubernetes-prod",
display_name: "Kubernetes Production",
},
{ id: "t-03", name: "aws-windows", display_name: "AWS Windows Desktop" },
{ id: "t-04", name: "gcp-linux", display_name: "GCP Linux Workspace" },
{ id: "t-05", name: "azure-dotnet", display_name: "Azure .NET Environment" },
{ id: "t-06", name: "ml-jupyter", display_name: "ML Jupyter Notebook" },
{
id: "t-07",
name: "data-eng-spark",
display_name: "Data Engineering (Spark)",
},
{
id: "t-08",
name: "frontend-vite",
display_name: "Frontend (Vite + React)",
},
].map((t) => ({ ...MockTemplate, ...t }));
export const TemplateAllowlist: Story = {
args: {
activeSection: "templates",
canManageChatModelConfigs: true,
canSetSystemPrompt: true,
},
beforeEach: () => {
// Track saved allowlist state across mock calls so the
// refetch after save returns the updated value.
let savedIDs: string[] = [];
spyOn(API, "getTemplates").mockResolvedValue(manyTemplates);
spyOn(API.experimental, "getChatTemplateAllowlist").mockImplementation(
async () => ({ template_ids: savedIDs }),
);
spyOn(API.experimental, "updateChatTemplateAllowlist").mockImplementation(
async (req) => {
savedIDs = [...req.template_ids];
},
);
},
play: async ({ canvasElement, step }) => {
const canvas = within(canvasElement);
await step("starts empty", async () => {
// Status text confirms no restrictions.
await canvas.findByText(/no templates selected/i);
// Save is disabled — nothing to save.
const saveBtn = await canvas.findByRole("button", { name: "Save" });
expect(saveBtn).toBeDisabled();
});
await step("select one template and save", async () => {
// Open the combobox.
const input = canvas.getByPlaceholderText("Select templates...");
await userEvent.click(input);
// Pick the first template from the dropdown.
await userEvent.click(
await canvas.findByRole("option", { name: "Docker Development" }),
);
// Badge pill should appear and status should update.
await waitFor(() => {
expect(canvas.getByText("1 template selected")).toBeInTheDocument();
});
// Save should now be enabled.
const saveBtn = canvas.getByRole("button", { name: "Save" });
expect(saveBtn).toBeEnabled();
await userEvent.click(saveBtn);
await waitFor(() => {
expect(
API.experimental.updateChatTemplateAllowlist,
).toHaveBeenCalledWith({ template_ids: ["t-01"] });
});
});
await step("add the remaining seven and save", async () => {
// Open the combobox again.
const input = canvas.getByLabelText("Select allowed templates");
await userEvent.click(input);
// Select the other seven templates one by one.
for (const name of [
"Kubernetes Production",
"AWS Windows Desktop",
"GCP Linux Workspace",
"Azure .NET Environment",
"ML Jupyter Notebook",
"Data Engineering (Spark)",
"Frontend (Vite + React)",
]) {
await userEvent.click(await canvas.findByRole("option", { name }));
}
// All eight should now be selected.
await waitFor(() => {
expect(canvas.getByText("8 templates selected")).toBeInTheDocument();
});
// Save.
const saveBtn = canvas.getByRole("button", { name: "Save" });
await userEvent.click(saveBtn);
await waitFor(() => {
expect(
API.experimental.updateChatTemplateAllowlist,
).toHaveBeenLastCalledWith({
template_ids: expect.arrayContaining([
"t-01",
"t-02",
"t-03",
"t-04",
"t-05",
"t-06",
"t-07",
"t-08",
]),
});
});
});
},
};
@@ -5,13 +5,16 @@ import {
chatDesktopEnabled,
chatModelConfigs,
chatSystemPrompt,
chatTemplateAllowlist,
chatUserCustomPrompt,
chatWorkspaceTTL,
updateChatDesktopEnabled,
updateChatSystemPrompt,
updateChatTemplateAllowlist,
updateChatWorkspaceTTL,
updateUserChatCustomPrompt,
} from "api/queries/chats";
import { templates } from "api/queries/templates";
import { user } from "api/queries/users";
import type * as TypesGen from "api/typesGenerated";
import dayjs from "dayjs";
@@ -35,6 +38,10 @@ import { Alert } from "#/components/Alert/Alert";
import { AvatarData } from "#/components/Avatar/AvatarData";
import { Button } from "#/components/Button/Button";
import { Link } from "#/components/Link/Link";
import {
MultiSelectCombobox,
type Option,
} from "#/components/MultiSelectCombobox/MultiSelectCombobox";
import { PaginationAmount } from "#/components/PaginationWidget/PaginationAmount";
import { PaginationWidgetBase } from "#/components/PaginationWidget/PaginationWidgetBase";
import { SearchField } from "#/components/SearchField/SearchField";
@@ -931,7 +938,147 @@ export const AgentSettingsPageView: FC<AgentSettingsPageViewProps> = ({
{activeSection === "insights" && canManageChatModelConfigs && (
<InsightsContent />
)}
{activeSection === "templates" && canManageChatModelConfigs && (
<TemplateAllowlistSection />
)}
</div>
</div>
);
};
const TemplateAllowlistSection: FC = () => {
const queryClient = useQueryClient();
// Fetch all available templates.
const templatesQuery = useQuery(templates());
// Fetch current allowlist.
const allowlistQuery = useQuery(chatTemplateAllowlist());
const {
mutate: saveAllowlist,
isPending: isSaving,
isError: isSaveError,
} = useMutation(updateChatTemplateAllowlist(queryClient));
const [localSelection, setLocalSelection] = useState<Option[] | null>(null);
// Map all templates to MultiSelectCombobox options.
const allOptions: Option[] = (templatesQuery.data ?? []).map((t) => ({
value: t.id,
label: t.display_name || t.name,
icon: t.icon,
}));
// Build a lookup from template ID to Option for resolving server IDs.
const optionsByID = new Map(allOptions.map((o) => [o.value, o]));
// Resolve the server-side allowlist IDs into Option objects.
const serverSelection: Option[] = (allowlistQuery.data?.template_ids ?? [])
.map((id) => optionsByID.get(id))
.filter((o) => o !== undefined);
const currentSelection = localSelection ?? serverSelection;
const serverSet = new Set(serverSelection.map((o) => o.value));
const isDirty =
localSelection !== null &&
(localSelection.length !== serverSet.size ||
localSelection.some((o) => !serverSet.has(o.value)));
const handleSave = (event: FormEvent) => {
event.preventDefault();
if (!isDirty) return;
saveAllowlist(
{ template_ids: currentSelection.map((o) => o.value) },
{ onSuccess: () => setLocalSelection(null) },
);
};
const isLoading = templatesQuery.isLoading || allowlistQuery.isLoading;
return (
<div className="space-y-6">
<SectionHeader
label="Templates"
description="Restrict which templates agents can use to create workspaces. When no templates are selected, all templates are available."
badge={<AdminBadge />}
/>
{isLoading && (
<div
role="status"
aria-label="Loading templates"
className="flex min-h-[120px] items-center justify-center"
>
<Spinner size="lg" loading className="text-content-secondary" />
</div>
)}
{!isLoading && (templatesQuery.error || allowlistQuery.error) && (
<div className="flex min-h-[120px] flex-col items-center justify-center gap-4 text-center">
<p className="m-0 text-sm text-content-secondary">
Failed to load template data.
</p>
<Button
variant="outline"
size="sm"
type="button"
onClick={() => {
void templatesQuery.refetch();
void allowlistQuery.refetch();
}}
>
Retry
</Button>
</div>
)}
{!isLoading && !templatesQuery.error && !allowlistQuery.error && (
<form
className="space-y-3"
onSubmit={(event) => void handleSave(event)}
>
<MultiSelectCombobox
key={serverSelection.map((o) => o.value).join(",")}
inputProps={{ "aria-label": "Select allowed templates" }}
options={allOptions}
defaultOptions={currentSelection}
value={currentSelection}
onChange={setLocalSelection}
placeholder="Select templates..."
emptyIndicator={
<p className="text-center text-sm text-content-secondary">
No templates found.
</p>
}
disabled={isSaving}
hidePlaceholderWhenSelected
data-testid="template-allowlist-select"
/>
<p
aria-live="polite"
role="status"
className="m-0 text-xs text-content-secondary"
>
{currentSelection.length > 0
? `${currentSelection.length} template${currentSelection.length !== 1 ? "s" : ""} selected`
: "No templates selected \u2014 all templates are available"}
</p>
<div className="flex justify-end">
<Button size="sm" type="submit" disabled={isSaving || !isDirty}>
Save
</Button>
</div>
{isSaveError && (
<p role="alert" className="m-0 text-xs text-content-destructive">
Failed to save template allowlist.
</p>
)}
</form>
)}
</div>
);
};
@@ -22,6 +22,7 @@ import {
GitPullRequestClosedIcon,
GitPullRequestDraftIcon,
KeyRoundIcon,
LayoutTemplateIcon,
Loader2Icon,
PanelLeftCloseIcon,
PauseIcon,
@@ -987,6 +988,14 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
/>
{isAdmin && (
<>
<SettingsNavItem
icon={LayoutTemplateIcon}
label="Templates"
active={sidebarView.section === "templates"}
to="/agents/settings/templates"
state={location.state}
adminOnly
/>
<SettingsNavItem
icon={KeyRoundIcon}
label="Providers"