fix(coderd): actually wire the chat template allowlist into tools (#23626)

Problem: previously, the deployment-wide chat template allowlist was never actually wired in from `chatd.go`

- Extracts `parseChatTemplateAllowlist` into shared `coderd/util/xjson.ParseUUIDList`
- Adds `Server.chatTemplateAllowlist()` method that reads the allowlist from DB
- Passes `AllowedTemplateIDs` callback to `ListTemplates`, `ReadTemplate`, and `CreateWorkspace` tool constructors

> 🤖 Created by Coder Agents and reviewed by a human.
This commit is contained in:
Cian Johnston
2026-03-25 22:15:27 +00:00
committed by GitHub
parent dab4e6f0a4
commit 7a9d57cd87
5 changed files with 262 additions and 23 deletions
+6 -19
View File
@@ -44,6 +44,7 @@ import (
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/util/xjson"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
@@ -2870,7 +2871,7 @@ func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request
})
return
}
ids, parseErr := parseChatTemplateAllowlist(raw)
parsed, parseErr := xjson.ParseUUIDList(raw)
if parseErr != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Stored template allowlist is corrupt.",
@@ -2878,6 +2879,10 @@ func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request
})
return
}
ids := make([]string, len(parsed))
for i, id := range parsed {
ids[i] = id.String()
}
resp := codersdk.ChatTemplateAllowlist{
TemplateIDs: ids,
}
@@ -2983,24 +2988,6 @@ func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request
rw.WriteHeader(http.StatusNoContent)
}
// parseChatTemplateAllowlist parses the raw JSON string from the
// database into a list of template ID strings. Returns an empty
// slice when the value is empty. Returns an error when the stored
// JSON is corrupt or otherwise cannot be unmarshalled.
func parseChatTemplateAllowlist(raw string) ([]string, error) {
if raw == "" {
return []string{}, nil
}
var ids []string
if err := json.Unmarshal([]byte(raw), &ids); err != nil {
return nil, xerrors.Errorf("unmarshal template allowlist: %w", err)
}
if ids == nil {
return []string{}, nil
}
return ids, nil
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
+35
View File
@@ -0,0 +1,35 @@
package xjson
import (
"encoding/json"
"strings"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
// ParseUUIDList parses a JSON-encoded array of UUID strings
// (e.g. `["uuid1","uuid2"]`) and returns the corresponding
// slice of uuid.UUID values. An empty input (including
// whitespace-only) returns an empty (non-nil) slice.
func ParseUUIDList(raw string) ([]uuid.UUID, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return []uuid.UUID{}, nil
}
var strs []string
if err := json.Unmarshal([]byte(raw), &strs); err != nil {
return nil, xerrors.Errorf("unmarshal uuid list: %w", err)
}
ids := make([]uuid.UUID, 0, len(strs))
for _, s := range strs {
id, err := uuid.Parse(s)
if err != nil {
return nil, xerrors.Errorf("parse uuid %q: %w", s, err)
}
ids = append(ids, id)
}
return ids, nil
}
+70
View File
@@ -0,0 +1,70 @@
package xjson_test
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/util/xjson"
)
func TestParseUUIDList(t *testing.T) {
t.Parallel()
a := uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5")
b := uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818")
tests := []struct {
name string
input string
want []uuid.UUID
wantErr string
}{
{
name: "EmptyString",
input: "",
want: []uuid.UUID{},
},
{
name: "JSONNull",
input: "null",
want: []uuid.UUID{},
},
{
name: "WhitespaceOnly",
input: " \n\t ",
want: []uuid.UUID{},
},
{
name: "ValidUUIDs",
input: `["c7c6686d-a93c-4df2-bef9-5f837e9a33d5","8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818"]`,
want: []uuid.UUID{a, b},
},
{
name: "InvalidJSON",
input: "not json at all",
wantErr: "unmarshal uuid list",
},
{
name: "InvalidUUID",
input: `["not-a-uuid"]`,
wantErr: "parse uuid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := xjson.ParseUUIDList(tt.input)
if tt.wantErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.wantErr)
return
}
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, tt.want, got)
})
}
}
+38 -4
View File
@@ -28,6 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/util/xjson"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
@@ -121,6 +122,36 @@ type Server struct {
chatHeartbeatInterval time.Duration
}
// chatTemplateAllowlist returns the deployment-wide template
// allowlist as a set of permitted template IDs. The callback
// signature matches what the chat tools expect. When the
// allowlist is empty or cannot be loaded the function returns
// nil, which the tools interpret as "all templates allowed".
func (p *Server) chatTemplateAllowlist() map[uuid.UUID]bool {
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
// access for reading deployment config.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
//nolint:gocritic // AsChatd provides narrowly-scoped read
// access to deployment config (the template allowlist).
ctx = dbauthz.AsChatd(ctx)
raw, err := p.db.GetChatTemplateAllowlist(ctx)
if err != nil {
p.logger.Warn(ctx, "failed to load chat template allowlist", slog.Error(err))
return nil
}
ids, err := xjson.ParseUUIDList(raw)
if err != nil {
p.logger.Warn(ctx, "failed to parse chat template allowlist", slog.Error(err))
return nil
}
m := make(map[uuid.UUID]bool, len(ids))
for _, id := range ids {
m[id] = true
}
return m
}
type turnWorkspaceContext struct {
server *Server
chatStateMu *sync.Mutex
@@ -3413,12 +3444,14 @@ func (p *Server) runChat(
// Workspace provisioning tools.
tools = append(tools,
chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: p.db,
OwnerID: chat.OwnerID,
DB: p.db,
OwnerID: chat.OwnerID,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: p.db,
OwnerID: chat.OwnerID,
DB: p.db,
OwnerID: chat.OwnerID,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: p.db,
@@ -3429,6 +3462,7 @@ func (p *Server) runChat(
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
WorkspaceMu: &workspaceMu,
Logger: p.logger,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: p.db,
+113
View File
@@ -3685,3 +3685,116 @@ func TestMCPServerToolInvocation(t *testing.T) {
require.True(t, foundToolMessage,
"MCP tool result should be persisted as a tool message in the database")
}
func TestChatTemplateAllowlistEnforcement(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, ps := dbtestutil.NewDB(t)
// Set up a mock OpenAI server. The first streaming call triggers
// list_templates; subsequent calls respond with text.
var callCount atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
if callCount.Add(1) == 1 {
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk("list_templates", `{}`),
)
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("Here are the templates.")...,
)
})
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
// Create two templates the user can see.
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
tplAllowed := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
Name: "allowed-template",
})
tplBlocked := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
Name: "blocked-template",
})
// Set the allowlist to only tplAllowed.
allowlistJSON, err := json.Marshal([]string{tplAllowed.ID.String()})
require.NoError(t, err)
err = db.UpsertChatTemplateAllowlist(dbauthz.AsSystemRestricted(ctx), string(allowlistJSON))
require.NoError(t, err)
server := newActiveTestServer(t, db, ps)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "allowlist-test",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("List templates"),
},
})
require.NoError(t, err)
// Wait for the chat to finish processing.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// Find the list_templates tool result in the persisted messages.
var toolResult string
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
if msg.Role != database.ChatMessageRoleTool {
continue
}
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolResult &&
part.ToolName == "list_templates" {
toolResult = string(part.Result)
return true
}
}
}
return false
}, testutil.IntervalFast)
require.NotEmpty(t, toolResult, "list_templates tool result should be persisted")
// The result should contain only the allowed template.
require.Contains(t, toolResult, tplAllowed.ID.String(),
"allowed template should appear in list_templates result")
require.NotContains(t, toolResult, tplBlocked.ID.String(),
"blocked template should NOT appear in list_templates result")
}