mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add deployment-wide template allowlist for chats (#23262)
- Stores a deployment-wide agents template allowlist in `site_configs` (`agents_template_allowlist`) - Adds `GET/PUT /api/experimental/chats/config/template-allowlist` endpoints - Filters `list_templates`, `read_template`, and `create_workspace` chat tools by allowlist, if defined (empty=all allowed) - Add "Templates" admin settings tab in Agents UI ([what it looks like](https://624de63c6aacee003aa84340-sitjilsyrr.chromatic.com/?path=/story/pages-agentspage-agentsettingspageview--template-allowlist)) > 🤖 This PR was created with the help of Coder Agents, and has been reviewed by my human. 🧑💻
This commit is contained in:
@@ -1186,6 +1186,8 @@ func New(options *Options) *API {
|
||||
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
|
||||
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
|
||||
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
|
||||
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
|
||||
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
|
||||
})
|
||||
// TODO(cian): place under /api/experimental/chats/config
|
||||
r.Route("/providers", func(r chi.Router) {
|
||||
|
||||
@@ -2674,6 +2674,17 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
return q.db.GetChatSystemPrompt(ctx)
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist requires deployment-config read permission,
|
||||
// unlike the peer getters (GetChatDesktopEnabled, etc.) which only
|
||||
// check actor presence. The allowlist is admin-configuration that
|
||||
// should not be readable by non-admin users via the HTTP API.
|
||||
func (q *querier) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.db.GetChatTemplateAllowlist(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
@@ -6812,6 +6823,13 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
|
||||
return q.db.UpsertChatSystemPrompt(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
|
||||
@@ -656,6 +656,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatTemplateAllowlist(gomock.Any()).Return("", nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatWorkspaceTTL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatWorkspaceTTL(gomock.Any()).Return("1h", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
@@ -873,6 +877,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatTemplateAllowlist(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatWorkspaceTTL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatWorkspaceTTL(gomock.Any(), "1h").Return(nil).AnyTimes()
|
||||
check.Args("1h").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
|
||||
@@ -1208,6 +1208,14 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatTemplateAllowlist(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatTemplateAllowlist").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatTemplateAllowlist").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
||||
@@ -4808,6 +4816,14 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatTemplateAllowlist(ctx, templateAllowlist)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatTemplateAllowlist").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatTemplateAllowlist").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
|
||||
@@ -2223,6 +2223,21 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist mocks base method.
|
||||
func (m *MockStore) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatTemplateAllowlist", ctx)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist indicates an expected call of GetChatTemplateAllowlist.
|
||||
func (mr *MockStoreMockRecorder) GetChatTemplateAllowlist(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).GetChatTemplateAllowlist), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9013,6 +9028,20 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatTemplateAllowlist mocks base method.
|
||||
func (m *MockStore) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatTemplateAllowlist", ctx, templateAllowlist)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatTemplateAllowlist indicates an expected call of UpsertChatTemplateAllowlist.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatTemplateAllowlist(ctx, templateAllowlist any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatTemplateAllowlist", reflect.TypeOf((*MockStore)(nil).UpsertChatTemplateAllowlist), ctx, templateAllowlist)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -254,6 +254,9 @@ type sqlcQuerier interface {
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
// Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
GetChatTemplateAllowlist(ctx context.Context) (string, error)
|
||||
GetChatUsageLimitConfig(ctx context.Context) (ChatUsageLimitConfig, error)
|
||||
GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (GetChatUsageLimitGroupOverrideRow, error)
|
||||
GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (GetChatUsageLimitUserOverrideRow, error)
|
||||
@@ -933,6 +936,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||
|
||||
@@ -17512,6 +17512,20 @@ func (q *sqlQuerier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
return chat_system_prompt, err
|
||||
}
|
||||
|
||||
const getChatTemplateAllowlist = `-- name: GetChatTemplateAllowlist :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist
|
||||
`
|
||||
|
||||
// GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
// Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
func (q *sqlQuerier) GetChatTemplateAllowlist(ctx context.Context) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatTemplateAllowlist)
|
||||
var template_allowlist string
|
||||
err := row.Scan(&template_allowlist)
|
||||
return template_allowlist, err
|
||||
}
|
||||
|
||||
const getChatWorkspaceTTL = `-- name: GetChatWorkspaceTTL :one
|
||||
SELECT
|
||||
COALESCE(
|
||||
@@ -17743,6 +17757,16 @@ func (q *sqlQuerier) UpsertChatSystemPrompt(ctx context.Context, value string) e
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatTemplateAllowlist = `-- name: UpsertChatTemplateAllowlist :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_template_allowlist', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_template_allowlist'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error {
|
||||
_, err := q.db.ExecContext(ctx, upsertChatTemplateAllowlist, templateAllowlist)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatWorkspaceTTL = `-- name: UpsertChatWorkspaceTTL :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_workspace_ttl', $1::text)
|
||||
|
||||
@@ -161,6 +161,12 @@ SET value = CASE
|
||||
END
|
||||
WHERE site_configs.key = 'agents_desktop_enabled';
|
||||
|
||||
-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
-- Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
-- name: GetChatTemplateAllowlist :one
|
||||
SELECT
|
||||
COALESCE((SELECT value FROM site_configs WHERE key = 'agents_template_allowlist'), '') :: text AS template_allowlist;
|
||||
|
||||
-- name: GetChatWorkspaceTTL :one
|
||||
-- Returns the global TTL for chat workspaces as a Go duration string.
|
||||
-- Returns "0s" (disabled) when no value has been configured.
|
||||
@@ -170,6 +176,10 @@ SELECT
|
||||
'0s'
|
||||
)::text AS workspace_ttl;
|
||||
|
||||
-- name: UpsertChatTemplateAllowlist :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('agents_template_allowlist', @template_allowlist)
|
||||
ON CONFLICT (key) DO UPDATE SET value = @template_allowlist WHERE site_configs.key = 'agents_template_allowlist';
|
||||
|
||||
-- name: UpsertChatWorkspaceTTL :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_workspace_ttl', @workspace_ttl::text)
|
||||
|
||||
@@ -2777,6 +2777,154 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
raw, err := api.Database.GetChatTemplateAllowlist(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching chat template allowlist.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
ids, parseErr := parseChatTemplateAllowlist(raw)
|
||||
if parseErr != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Stored template allowlist is corrupt.",
|
||||
Detail: parseErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp := codersdk.ChatTemplateAllowlist{
|
||||
TemplateIDs: ids,
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.ChatTemplateAllowlist
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate all entries are valid UUIDs and deduplicate.
|
||||
seen := make(map[string]struct{}, len(req.TemplateIDs))
|
||||
deduped := make([]string, 0, len(req.TemplateIDs))
|
||||
for _, id := range req.TemplateIDs {
|
||||
parsed, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid template ID in allowlist.",
|
||||
Detail: fmt.Sprintf("%q is not a valid UUID.", id),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Canonicalize to lowercase so deduplication is
|
||||
// case-insensitive and stored values are consistent.
|
||||
canonical := parsed.String()
|
||||
if _, ok := seen[canonical]; !ok {
|
||||
seen[canonical] = struct{}{}
|
||||
deduped = append(deduped, canonical)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to UUIDs for the database query.
|
||||
parsedUUIDs := make([]uuid.UUID, len(deduped))
|
||||
for i, s := range deduped {
|
||||
// Already validated above, safe to ignore error.
|
||||
parsedUUIDs[i], _ = uuid.Parse(s)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(deduped)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error encoding template allowlist.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
// Verify all IDs refer to existing, non-deprecated templates
|
||||
// in a single query.
|
||||
if len(parsedUUIDs) > 0 {
|
||||
found, err := tx.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
|
||||
IDs: parsedUUIDs,
|
||||
Deprecated: sql.NullBool{
|
||||
Bool: false,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch templates: %w", err)
|
||||
}
|
||||
if len(found) != len(parsedUUIDs) {
|
||||
foundSet := make(map[uuid.UUID]struct{}, len(found))
|
||||
for _, t := range found {
|
||||
foundSet[t.ID] = struct{}{}
|
||||
}
|
||||
var missing []string
|
||||
for _, id := range parsedUUIDs {
|
||||
if _, ok := foundSet[id]; !ok {
|
||||
missing = append(missing, id.String())
|
||||
}
|
||||
}
|
||||
return xerrors.Errorf("templates not found or deprecated: %s", strings.Join(missing, ", "))
|
||||
}
|
||||
}
|
||||
return tx.UpsertChatTemplateAllowlist(ctx, string(raw))
|
||||
}, nil)
|
||||
if err != nil {
|
||||
// If the error mentions "not found or deprecated", it's a
|
||||
// validation failure, not an internal error.
|
||||
if strings.Contains(err.Error(), "not found or deprecated") {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "One or more templates not found or deprecated.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating chat template allowlist.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// parseChatTemplateAllowlist parses the raw JSON string from the
|
||||
// database into a list of template ID strings. Returns an empty
|
||||
// slice when the value is empty. Returns an error when the stored
|
||||
// JSON is corrupt or otherwise cannot be unmarshalled.
|
||||
func parseChatTemplateAllowlist(raw string) ([]string, error) {
|
||||
if raw == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
var ids []string
|
||||
if err := json.Unmarshal([]byte(raw), &ids); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal template allowlist: %w", err)
|
||||
}
|
||||
if ids == nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
|
||||
@@ -5161,6 +5161,134 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel // Subtests share a single coderdtest instance and run sequentially.
|
||||
func TestChatTemplateAllowlist(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Shared setup: one coderdtest instance with two real templates.
|
||||
// Subtests that need valid template IDs use these.
|
||||
client, store := newChatClientWithDatabase(t)
|
||||
admin := coderdtest.CreateFirstUser(t, client.Client)
|
||||
tmpl1 := dbgen.Template(t, store, database.Template{
|
||||
OrganizationID: admin.OrganizationID,
|
||||
CreatedBy: admin.UserID,
|
||||
})
|
||||
tmpl2 := dbgen.Template(t, store, database.Template{
|
||||
OrganizationID: admin.OrganizationID,
|
||||
CreatedBy: admin.UserID,
|
||||
})
|
||||
deprecatedTmpl := dbgen.Template(t, store, database.Template{
|
||||
OrganizationID: admin.OrganizationID,
|
||||
CreatedBy: admin.UserID,
|
||||
})
|
||||
//nolint:gocritic // Owner context needed to deprecate the template in test setup.
|
||||
ownerRoles, err := rbac.RoleIdentifiers{rbac.RoleOwner()}.Expand()
|
||||
require.NoError(t, err)
|
||||
err = store.UpdateTemplateAccessControlByID(dbauthz.As(context.Background(), rbac.Subject{
|
||||
ID: "owner",
|
||||
Roles: rbac.Roles(ownerRoles),
|
||||
Scope: rbac.ExpandableScope(rbac.ScopeAll),
|
||||
}), database.UpdateTemplateAccessControlByIDParams{
|
||||
ID: deprecatedTmpl.ID,
|
||||
Deprecated: "this template is deprecated",
|
||||
})
|
||||
require.NoError(t, err, "deprecate template")
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
resp, err := client.GetChatTemplateAllowlist(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.TemplateIDs)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("AdminCanSet", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ids := []string{tmpl1.ID.String(), tmpl2.ID.String()}
|
||||
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: ids})
|
||||
require.NoError(t, err)
|
||||
resp, err := client.GetChatTemplateAllowlist(ctx)
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, ids, resp.TemplateIDs)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("AdminCanClear", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{}})
|
||||
require.NoError(t, err)
|
||||
resp, err := client.GetChatTemplateAllowlist(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.TemplateIDs)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("NonAdminReadFails", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
_, err := memberClient.GetChatTemplateAllowlist(ctx)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("NonAdminWriteFails", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
// Uses a random UUID — hits 404 before template validation.
|
||||
err := memberClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("UnauthenticatedFails", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
anonClient := codersdk.NewExperimentalClient(codersdk.New(client.URL))
|
||||
// Uses a random UUID — hits 401 before template validation.
|
||||
err := anonClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}})
|
||||
requireSDKError(t, err, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("InvalidUUIDRejected", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{"not-a-uuid"}})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("NonexistentTemplateRejected", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("DeprecatedTemplateRejected", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{
|
||||
TemplateIDs: []string{deprecatedTmpl.ID.String()},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Sequential: subtests share a single coderdtest instance.
|
||||
t.Run("DeduplicatesIDs", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
id := tmpl1.ID.String()
|
||||
err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{
|
||||
TemplateIDs: []string{id, id, id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
resp, err := client.GetChatTemplateAllowlist(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.TemplateIDs, 1)
|
||||
require.Equal(t, id, resp.TemplateIDs[0])
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+50
-4
@@ -3371,6 +3371,45 @@ func (p *Server) runChat(
|
||||
GetWorkspaceConn: workspaceCtx.getWorkspaceConn,
|
||||
}),
|
||||
}
|
||||
// getAllowedTemplateIDs returns the current deployment-wide
|
||||
// template allowlist, re-reading from the database on each call
|
||||
// so that admin changes take effect without restarting the chat.
|
||||
// Returns nil (= all allowed) on errors to fail open.
|
||||
getAllowedTemplateIDs := func() map[uuid.UUID]bool {
|
||||
raw, err := p.db.GetChatTemplateAllowlist(ctx)
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to load template allowlist, all templates will be allowed", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
var ids []string
|
||||
if jsonErr := json.Unmarshal([]byte(raw), &ids); jsonErr != nil {
|
||||
// Note: the API endpoint (GET /template-allowlist) returns
|
||||
// HTTP 500 for corrupt JSON, giving admins visibility into
|
||||
// the problem. The runtime path here deliberately fails open
|
||||
// so that a corrupt allowlist doesn't block all chats.
|
||||
p.logger.Error(ctx, "failed to parse template allowlist JSON, all templates will be allowed",
|
||||
slog.F("raw", raw), slog.Error(jsonErr))
|
||||
return nil
|
||||
}
|
||||
allowlist := make(map[uuid.UUID]bool, len(ids))
|
||||
for _, s := range ids {
|
||||
if id, parseErr := uuid.Parse(s); parseErr == nil {
|
||||
allowlist[id] = true
|
||||
} else {
|
||||
p.logger.Warn(ctx, "ignoring invalid UUID in template allowlist",
|
||||
slog.F("value", s), slog.Error(parseErr))
|
||||
}
|
||||
}
|
||||
if len(ids) > 0 && len(allowlist) == 0 {
|
||||
p.logger.Error(ctx, "all UUIDs in template allowlist were invalid, all templates will be allowed",
|
||||
slog.F("count", len(ids)))
|
||||
return nil
|
||||
}
|
||||
return allowlist
|
||||
}
|
||||
// Only root chats (not delegated subagents) get workspace
|
||||
// provisioning and subagent tools. Child agents must not
|
||||
// create workspaces or spawn further subagents — they should
|
||||
@@ -3379,12 +3418,14 @@ func (p *Server) runChat(
|
||||
// Workspace provisioning tools.
|
||||
tools = append(tools,
|
||||
chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
AllowedTemplateIDs: getAllowedTemplateIDs,
|
||||
}),
|
||||
chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
AllowedTemplateIDs: getAllowedTemplateIDs,
|
||||
}),
|
||||
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: p.db,
|
||||
@@ -3395,7 +3436,12 @@ func (p *Server) runChat(
|
||||
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
|
||||
WorkspaceMu: &workspaceMu,
|
||||
Logger: p.logger,
|
||||
AllowedTemplateIDs: getAllowedTemplateIDs,
|
||||
}),
|
||||
// StartWorkspace intentionally does not enforce the
|
||||
// template allowlist. The allowlist restricts creation
|
||||
// of new workspaces only — existing workspaces can
|
||||
// be restarted regardless of allowlist changes.
|
||||
chattool.StartWorkspace(chattool.StartWorkspaceOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// toolResponse builds a fantasy.ToolResponse from a JSON-serializable
|
||||
@@ -31,3 +32,17 @@ func truncateRunes(value string, maxLen int) string {
|
||||
}
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
|
||||
// isTemplateAllowed checks whether a template ID is permitted by the
|
||||
// configured allowlist. A nil function or an empty allowlist means
|
||||
// all templates are allowed.
|
||||
func isTemplateAllowed(getAllowlist func() map[uuid.UUID]bool, id uuid.UUID) bool {
|
||||
if getAllowlist == nil {
|
||||
return true
|
||||
}
|
||||
allowlist := getAllowlist()
|
||||
if len(allowlist) == 0 {
|
||||
return true
|
||||
}
|
||||
return allowlist[id]
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ type CreateWorkspaceOptions struct {
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
WorkspaceMu *sync.Mutex
|
||||
Logger slog.Logger
|
||||
AllowedTemplateIDs func() map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
type createWorkspaceArgs struct {
|
||||
@@ -106,6 +107,10 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
), nil
|
||||
}
|
||||
|
||||
if !isTemplateAllowed(options.AllowedTemplateIDs, templateID) {
|
||||
return fantasy.NewTextErrorResponse("template not available for chat workspaces; use list_templates to find allowed templates"), nil
|
||||
}
|
||||
|
||||
// Serialize workspace creation to prevent parallel
|
||||
// tool calls from creating duplicate workspaces.
|
||||
if options.WorkspaceMu != nil {
|
||||
@@ -121,7 +126,6 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
if done {
|
||||
return toolResponse(existing), nil
|
||||
}
|
||||
|
||||
ownerID := options.OwnerID
|
||||
|
||||
// Set up dbauthz context for DB lookups.
|
||||
|
||||
@@ -3,6 +3,8 @@ package chattool
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"maps"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -20,8 +22,9 @@ const listTemplatesPageSize = 10
|
||||
|
||||
// ListTemplatesOptions configures the list_templates tool.
|
||||
type ListTemplatesOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
AllowedTemplateIDs func() map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
type listTemplatesArgs struct {
|
||||
@@ -63,6 +66,13 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
filterParams.FuzzyName = query
|
||||
}
|
||||
|
||||
var allowlist map[uuid.UUID]bool
|
||||
if options.AllowedTemplateIDs != nil {
|
||||
allowlist = options.AllowedTemplateIDs()
|
||||
}
|
||||
if len(allowlist) > 0 {
|
||||
filterParams.IDs = slices.Collect(maps.Keys(allowlist))
|
||||
}
|
||||
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single DB and run sequentially.
|
||||
func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
|
||||
t1 := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "template-alpha",
|
||||
})
|
||||
t2 := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "template-beta",
|
||||
})
|
||||
|
||||
t.Run("ListTemplates", func(t *testing.T) {
|
||||
t.Run("NoAllowlist", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c1", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
templates := result["templates"].([]any)
|
||||
require.Len(t, templates, 2)
|
||||
})
|
||||
|
||||
t.Run("EmptyAllowlist", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{} },
|
||||
})
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c2", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
templates := result["templates"].([]any)
|
||||
require.Len(t, templates, 2)
|
||||
})
|
||||
|
||||
t.Run("OneMatch", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
|
||||
})
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c3", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
templates := result["templates"].([]any)
|
||||
require.Len(t, templates, 1)
|
||||
m := templates[0].(map[string]any)
|
||||
require.Equal(t, t1.ID.String(), m["id"].(string))
|
||||
})
|
||||
|
||||
t.Run("NoMatches", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
|
||||
})
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c4", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
templates := result["templates"].([]any)
|
||||
require.Empty(t, templates)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ReadTemplate", func(t *testing.T) {
|
||||
t.Run("Allowed", func(t *testing.T) {
|
||||
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
|
||||
})
|
||||
input := `{"template_id":"` + t1.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c5", Name: "read_template", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError)
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
tmplInfo := result["template"].(map[string]any)
|
||||
require.Equal(t, t1.ID.String(), tmplInfo["id"].(string))
|
||||
})
|
||||
|
||||
t.Run("Disallowed", func(t *testing.T) {
|
||||
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
|
||||
})
|
||||
input := `{"template_id":"` + t2.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c6", Name: "read_template", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.IsError)
|
||||
require.Contains(t, resp.Content, "not found")
|
||||
})
|
||||
|
||||
t.Run("NoAllowlist", func(t *testing.T) {
|
||||
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
input := `{"template_id":"` + t2.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c7", Name: "read_template", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CreateWorkspace", func(t *testing.T) {
|
||||
t.Run("Allowed", func(t *testing.T) {
|
||||
createCalled := false
|
||||
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
|
||||
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
createCalled = true
|
||||
return codersdk.Workspace{}, nil
|
||||
},
|
||||
})
|
||||
input := `{"template_id":"` + t1.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8a", Name: "create_workspace", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.True(t, createCalled, "CreateFn should be called for allowed template")
|
||||
// We don't assert resp.IsError here because CreateWorkspace
|
||||
// does additional work (asOwner, workspace lookup) that
|
||||
// depends on full RBAC setup. The key assertion is that
|
||||
// the allowlist gate passed and CreateFn was invoked.
|
||||
_ = resp
|
||||
})
|
||||
|
||||
t.Run("Disallowed", func(t *testing.T) {
|
||||
createCalled := false
|
||||
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
|
||||
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
createCalled = true
|
||||
t.Fatal("CreateFn should not be called for blocked template")
|
||||
return codersdk.Workspace{}, nil
|
||||
},
|
||||
})
|
||||
input := `{"template_id":"` + t1.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8", Name: "create_workspace", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.IsError)
|
||||
require.Contains(t, resp.Content, "template not available for chat workspaces")
|
||||
require.False(t, createCalled, "CreateFn should not be called for blocked template")
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -14,8 +14,9 @@ import (
|
||||
|
||||
// ReadTemplateOptions configures the read_template tool.
|
||||
type ReadTemplateOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
AllowedTemplateIDs func() map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
type readTemplateArgs struct {
|
||||
@@ -48,6 +49,10 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
|
||||
), nil
|
||||
}
|
||||
|
||||
if !isTemplateAllowed(options.AllowedTemplateIDs, templateID) {
|
||||
return fantasy.NewTextErrorResponse("template not found"), nil
|
||||
}
|
||||
|
||||
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
|
||||
@@ -425,6 +425,13 @@ func ParseChatWorkspaceTTL(s string) (time.Duration, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// ChatTemplateAllowlist is the request and response body for the
|
||||
// chat template allowlist configuration endpoint. An empty list
|
||||
// means all templates are allowed.
|
||||
type ChatTemplateAllowlist struct {
|
||||
TemplateIDs []string `json:"template_ids"`
|
||||
}
|
||||
|
||||
// ChatProviderConfigSource describes how a provider entry is sourced.
|
||||
type ChatProviderConfigSource string
|
||||
|
||||
@@ -1444,6 +1451,33 @@ func (c *ExperimentalClient) UpdateChatWorkspaceTTL(ctx context.Context, req Upd
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist returns the deployment-wide chat template allowlist.
|
||||
func (c *ExperimentalClient) GetChatTemplateAllowlist(ctx context.Context) (ChatTemplateAllowlist, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/template-allowlist", nil)
|
||||
if err != nil {
|
||||
return ChatTemplateAllowlist{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatTemplateAllowlist{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatTemplateAllowlist
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// UpdateChatTemplateAllowlist updates the deployment-wide chat template allowlist.
|
||||
func (c *ExperimentalClient) UpdateChatTemplateAllowlist(ctx context.Context, req ChatTemplateAllowlist) error {
|
||||
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/template-allowlist", req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUserChatCustomPrompt updates the user's custom chat prompt.
|
||||
func (c *ExperimentalClient) UpdateUserChatCustomPrompt(ctx context.Context, req UserChatCustomPrompt) (UserChatCustomPrompt, error) {
|
||||
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/user-prompt", req)
|
||||
|
||||
@@ -3216,12 +3216,29 @@ class ExperimentalApiMethods {
|
||||
return response.data;
|
||||
};
|
||||
|
||||
getChatTemplateAllowlist =
|
||||
async (): Promise<TypesGen.ChatTemplateAllowlist> => {
|
||||
const response = await this.axios.get<TypesGen.ChatTemplateAllowlist>(
|
||||
"/api/experimental/chats/config/template-allowlist",
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
updateChatWorkspaceTTL = async (
|
||||
req: TypesGen.UpdateChatWorkspaceTTLRequest,
|
||||
): Promise<void> => {
|
||||
await this.axios.put("/api/experimental/chats/config/workspace-ttl", req);
|
||||
};
|
||||
|
||||
updateChatTemplateAllowlist = async (
|
||||
req: TypesGen.ChatTemplateAllowlist,
|
||||
): Promise<void> => {
|
||||
await this.axios.put(
|
||||
"/api/experimental/chats/config/template-allowlist",
|
||||
req,
|
||||
);
|
||||
};
|
||||
|
||||
getUserChatCustomPrompt =
|
||||
async (): Promise<TypesGen.UserChatCustomPrompt> => {
|
||||
const response = await this.axios.get<TypesGen.UserChatCustomPrompt>(
|
||||
|
||||
@@ -439,6 +439,22 @@ export const updateChatWorkspaceTTL = (queryClient: QueryClient) => ({
|
||||
},
|
||||
});
|
||||
|
||||
const chatTemplateAllowlistKey = ["chat-template-allowlist"] as const;
|
||||
|
||||
export const chatTemplateAllowlist = () => ({
|
||||
queryKey: chatTemplateAllowlistKey,
|
||||
queryFn: () => API.experimental.getChatTemplateAllowlist(),
|
||||
});
|
||||
|
||||
export const updateChatTemplateAllowlist = (queryClient: QueryClient) => ({
|
||||
mutationFn: API.experimental.updateChatTemplateAllowlist,
|
||||
onSuccess: async () => {
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: chatTemplateAllowlistKey,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const chatUserCustomPromptKey = ["chat-user-custom-prompt"] as const;
|
||||
|
||||
export const chatUserCustomPrompt = () => ({
|
||||
|
||||
Generated
+10
@@ -1880,6 +1880,16 @@ export interface ChatSystemPrompt {
|
||||
readonly system_prompt: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatTemplateAllowlist is the request and response body for the
|
||||
* chat template allowlist configuration endpoint. An empty list
|
||||
* means all templates are allowed.
|
||||
*/
|
||||
export interface ChatTemplateAllowlist {
|
||||
readonly template_ids: readonly string[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export interface ChatTextPart {
|
||||
readonly type: "text";
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { MockUserOwner } from "testHelpers/entities";
|
||||
import { MockTemplate, MockUserOwner } from "testHelpers/entities";
|
||||
import { withAuthProvider, withDashboardProvider } from "testHelpers/storybook";
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { API } from "api/api";
|
||||
@@ -170,6 +170,30 @@ const meta = {
|
||||
workspace_ttl_ms: 0,
|
||||
});
|
||||
spyOn(API.experimental, "updateChatWorkspaceTTL").mockResolvedValue();
|
||||
spyOn(API.experimental, "getChatTemplateAllowlist").mockResolvedValue({
|
||||
template_ids: [],
|
||||
});
|
||||
spyOn(API.experimental, "updateChatTemplateAllowlist").mockResolvedValue();
|
||||
spyOn(API, "getTemplates").mockResolvedValue([
|
||||
{
|
||||
...MockTemplate,
|
||||
id: "abc-123",
|
||||
name: "docker-dev",
|
||||
display_name: "Docker Development",
|
||||
},
|
||||
{
|
||||
...MockTemplate,
|
||||
id: "def-456",
|
||||
name: "kubernetes-prod",
|
||||
display_name: "Kubernetes Production",
|
||||
},
|
||||
{
|
||||
...MockTemplate,
|
||||
id: "ghi-789",
|
||||
name: "aws-windows",
|
||||
display_name: "AWS Windows Desktop",
|
||||
},
|
||||
]);
|
||||
},
|
||||
} satisfies Meta<typeof AgentSettingsPageView>;
|
||||
|
||||
@@ -837,3 +861,126 @@ export const NoWarningForCleanPrompt: Story = {
|
||||
expect(canvas.queryByText(/invisible Unicode/)).toBeNull();
|
||||
},
|
||||
};
|
||||
|
||||
// ── Templates tab stories ──────────────────────────────────────
|
||||
|
||||
const manyTemplates = [
|
||||
{ id: "t-01", name: "docker-dev", display_name: "Docker Development" },
|
||||
{
|
||||
id: "t-02",
|
||||
name: "kubernetes-prod",
|
||||
display_name: "Kubernetes Production",
|
||||
},
|
||||
{ id: "t-03", name: "aws-windows", display_name: "AWS Windows Desktop" },
|
||||
{ id: "t-04", name: "gcp-linux", display_name: "GCP Linux Workspace" },
|
||||
{ id: "t-05", name: "azure-dotnet", display_name: "Azure .NET Environment" },
|
||||
{ id: "t-06", name: "ml-jupyter", display_name: "ML Jupyter Notebook" },
|
||||
{
|
||||
id: "t-07",
|
||||
name: "data-eng-spark",
|
||||
display_name: "Data Engineering (Spark)",
|
||||
},
|
||||
{
|
||||
id: "t-08",
|
||||
name: "frontend-vite",
|
||||
display_name: "Frontend (Vite + React)",
|
||||
},
|
||||
].map((t) => ({ ...MockTemplate, ...t }));
|
||||
|
||||
export const TemplateAllowlist: Story = {
|
||||
args: {
|
||||
activeSection: "templates",
|
||||
canManageChatModelConfigs: true,
|
||||
canSetSystemPrompt: true,
|
||||
},
|
||||
beforeEach: () => {
|
||||
// Track saved allowlist state across mock calls so the
|
||||
// refetch after save returns the updated value.
|
||||
let savedIDs: string[] = [];
|
||||
|
||||
spyOn(API, "getTemplates").mockResolvedValue(manyTemplates);
|
||||
spyOn(API.experimental, "getChatTemplateAllowlist").mockImplementation(
|
||||
async () => ({ template_ids: savedIDs }),
|
||||
);
|
||||
spyOn(API.experimental, "updateChatTemplateAllowlist").mockImplementation(
|
||||
async (req) => {
|
||||
savedIDs = [...req.template_ids];
|
||||
},
|
||||
);
|
||||
},
|
||||
play: async ({ canvasElement, step }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
await step("starts empty", async () => {
|
||||
// Status text confirms no restrictions.
|
||||
await canvas.findByText(/no templates selected/i);
|
||||
// Save is disabled — nothing to save.
|
||||
const saveBtn = await canvas.findByRole("button", { name: "Save" });
|
||||
expect(saveBtn).toBeDisabled();
|
||||
});
|
||||
|
||||
await step("select one template and save", async () => {
|
||||
// Open the combobox.
|
||||
const input = canvas.getByPlaceholderText("Select templates...");
|
||||
await userEvent.click(input);
|
||||
// Pick the first template from the dropdown.
|
||||
await userEvent.click(
|
||||
await canvas.findByRole("option", { name: "Docker Development" }),
|
||||
);
|
||||
// Badge pill should appear and status should update.
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("1 template selected")).toBeInTheDocument();
|
||||
});
|
||||
// Save should now be enabled.
|
||||
const saveBtn = canvas.getByRole("button", { name: "Save" });
|
||||
expect(saveBtn).toBeEnabled();
|
||||
await userEvent.click(saveBtn);
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
API.experimental.updateChatTemplateAllowlist,
|
||||
).toHaveBeenCalledWith({ template_ids: ["t-01"] });
|
||||
});
|
||||
});
|
||||
|
||||
await step("add the remaining seven and save", async () => {
|
||||
// Open the combobox again.
|
||||
const input = canvas.getByLabelText("Select allowed templates");
|
||||
await userEvent.click(input);
|
||||
// Select the other seven templates one by one.
|
||||
for (const name of [
|
||||
"Kubernetes Production",
|
||||
"AWS Windows Desktop",
|
||||
"GCP Linux Workspace",
|
||||
"Azure .NET Environment",
|
||||
"ML Jupyter Notebook",
|
||||
"Data Engineering (Spark)",
|
||||
"Frontend (Vite + React)",
|
||||
]) {
|
||||
await userEvent.click(await canvas.findByRole("option", { name }));
|
||||
}
|
||||
// All eight should now be selected.
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByText("8 templates selected")).toBeInTheDocument();
|
||||
});
|
||||
// Save.
|
||||
const saveBtn = canvas.getByRole("button", { name: "Save" });
|
||||
await userEvent.click(saveBtn);
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
API.experimental.updateChatTemplateAllowlist,
|
||||
).toHaveBeenLastCalledWith({
|
||||
template_ids: expect.arrayContaining([
|
||||
"t-01",
|
||||
"t-02",
|
||||
"t-03",
|
||||
"t-04",
|
||||
"t-05",
|
||||
"t-06",
|
||||
"t-07",
|
||||
"t-08",
|
||||
]),
|
||||
});
|
||||
});
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
@@ -5,13 +5,16 @@ import {
|
||||
chatDesktopEnabled,
|
||||
chatModelConfigs,
|
||||
chatSystemPrompt,
|
||||
chatTemplateAllowlist,
|
||||
chatUserCustomPrompt,
|
||||
chatWorkspaceTTL,
|
||||
updateChatDesktopEnabled,
|
||||
updateChatSystemPrompt,
|
||||
updateChatTemplateAllowlist,
|
||||
updateChatWorkspaceTTL,
|
||||
updateUserChatCustomPrompt,
|
||||
} from "api/queries/chats";
|
||||
import { templates } from "api/queries/templates";
|
||||
import { user } from "api/queries/users";
|
||||
import type * as TypesGen from "api/typesGenerated";
|
||||
import dayjs from "dayjs";
|
||||
@@ -35,6 +38,10 @@ import { Alert } from "#/components/Alert/Alert";
|
||||
import { AvatarData } from "#/components/Avatar/AvatarData";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import { Link } from "#/components/Link/Link";
|
||||
import {
|
||||
MultiSelectCombobox,
|
||||
type Option,
|
||||
} from "#/components/MultiSelectCombobox/MultiSelectCombobox";
|
||||
import { PaginationAmount } from "#/components/PaginationWidget/PaginationAmount";
|
||||
import { PaginationWidgetBase } from "#/components/PaginationWidget/PaginationWidgetBase";
|
||||
import { SearchField } from "#/components/SearchField/SearchField";
|
||||
@@ -931,7 +938,147 @@ export const AgentSettingsPageView: FC<AgentSettingsPageViewProps> = ({
|
||||
{activeSection === "insights" && canManageChatModelConfigs && (
|
||||
<InsightsContent />
|
||||
)}
|
||||
{activeSection === "templates" && canManageChatModelConfigs && (
|
||||
<TemplateAllowlistSection />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const TemplateAllowlistSection: FC = () => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
// Fetch all available templates.
|
||||
const templatesQuery = useQuery(templates());
|
||||
|
||||
// Fetch current allowlist.
|
||||
const allowlistQuery = useQuery(chatTemplateAllowlist());
|
||||
|
||||
const {
|
||||
mutate: saveAllowlist,
|
||||
isPending: isSaving,
|
||||
isError: isSaveError,
|
||||
} = useMutation(updateChatTemplateAllowlist(queryClient));
|
||||
|
||||
const [localSelection, setLocalSelection] = useState<Option[] | null>(null);
|
||||
|
||||
// Map all templates to MultiSelectCombobox options.
|
||||
const allOptions: Option[] = (templatesQuery.data ?? []).map((t) => ({
|
||||
value: t.id,
|
||||
label: t.display_name || t.name,
|
||||
icon: t.icon,
|
||||
}));
|
||||
|
||||
// Build a lookup from template ID to Option for resolving server IDs.
|
||||
const optionsByID = new Map(allOptions.map((o) => [o.value, o]));
|
||||
|
||||
// Resolve the server-side allowlist IDs into Option objects.
|
||||
const serverSelection: Option[] = (allowlistQuery.data?.template_ids ?? [])
|
||||
.map((id) => optionsByID.get(id))
|
||||
.filter((o) => o !== undefined);
|
||||
|
||||
const currentSelection = localSelection ?? serverSelection;
|
||||
|
||||
const serverSet = new Set(serverSelection.map((o) => o.value));
|
||||
const isDirty =
|
||||
localSelection !== null &&
|
||||
(localSelection.length !== serverSet.size ||
|
||||
localSelection.some((o) => !serverSet.has(o.value)));
|
||||
|
||||
const handleSave = (event: FormEvent) => {
|
||||
event.preventDefault();
|
||||
if (!isDirty) return;
|
||||
saveAllowlist(
|
||||
{ template_ids: currentSelection.map((o) => o.value) },
|
||||
{ onSuccess: () => setLocalSelection(null) },
|
||||
);
|
||||
};
|
||||
|
||||
const isLoading = templatesQuery.isLoading || allowlistQuery.isLoading;
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<SectionHeader
|
||||
label="Templates"
|
||||
description="Restrict which templates agents can use to create workspaces. When no templates are selected, all templates are available."
|
||||
badge={<AdminBadge />}
|
||||
/>
|
||||
|
||||
{isLoading && (
|
||||
<div
|
||||
role="status"
|
||||
aria-label="Loading templates"
|
||||
className="flex min-h-[120px] items-center justify-center"
|
||||
>
|
||||
<Spinner size="lg" loading className="text-content-secondary" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isLoading && (templatesQuery.error || allowlistQuery.error) && (
|
||||
<div className="flex min-h-[120px] flex-col items-center justify-center gap-4 text-center">
|
||||
<p className="m-0 text-sm text-content-secondary">
|
||||
Failed to load template data.
|
||||
</p>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
type="button"
|
||||
onClick={() => {
|
||||
void templatesQuery.refetch();
|
||||
void allowlistQuery.refetch();
|
||||
}}
|
||||
>
|
||||
Retry
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isLoading && !templatesQuery.error && !allowlistQuery.error && (
|
||||
<form
|
||||
className="space-y-3"
|
||||
onSubmit={(event) => void handleSave(event)}
|
||||
>
|
||||
<MultiSelectCombobox
|
||||
key={serverSelection.map((o) => o.value).join(",")}
|
||||
inputProps={{ "aria-label": "Select allowed templates" }}
|
||||
options={allOptions}
|
||||
defaultOptions={currentSelection}
|
||||
value={currentSelection}
|
||||
onChange={setLocalSelection}
|
||||
placeholder="Select templates..."
|
||||
emptyIndicator={
|
||||
<p className="text-center text-sm text-content-secondary">
|
||||
No templates found.
|
||||
</p>
|
||||
}
|
||||
disabled={isSaving}
|
||||
hidePlaceholderWhenSelected
|
||||
data-testid="template-allowlist-select"
|
||||
/>
|
||||
<p
|
||||
aria-live="polite"
|
||||
role="status"
|
||||
className="m-0 text-xs text-content-secondary"
|
||||
>
|
||||
{currentSelection.length > 0
|
||||
? `${currentSelection.length} template${currentSelection.length !== 1 ? "s" : ""} selected`
|
||||
: "No templates selected \u2014 all templates are available"}
|
||||
</p>
|
||||
|
||||
<div className="flex justify-end">
|
||||
<Button size="sm" type="submit" disabled={isSaving || !isDirty}>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{isSaveError && (
|
||||
<p role="alert" className="m-0 text-xs text-content-destructive">
|
||||
Failed to save template allowlist.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
GitPullRequestClosedIcon,
|
||||
GitPullRequestDraftIcon,
|
||||
KeyRoundIcon,
|
||||
LayoutTemplateIcon,
|
||||
Loader2Icon,
|
||||
PanelLeftCloseIcon,
|
||||
PauseIcon,
|
||||
@@ -987,6 +988,14 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
|
||||
/>
|
||||
{isAdmin && (
|
||||
<>
|
||||
<SettingsNavItem
|
||||
icon={LayoutTemplateIcon}
|
||||
label="Templates"
|
||||
active={sidebarView.section === "templates"}
|
||||
to="/agents/settings/templates"
|
||||
state={location.state}
|
||||
adminOnly
|
||||
/>
|
||||
<SettingsNavItem
|
||||
icon={KeyRoundIcon}
|
||||
label="Providers"
|
||||
|
||||
Reference in New Issue
Block a user