mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): actually wire the chat template allowlist into tools (#23626)
Problem: previously, the deployment-wide chat template allowlist was never actually wired in from `chatd.go`
- Extracts `parseChatTemplateAllowlist` into shared `coderd/util/xjson.ParseUUIDList`
- Adds `Server.chatTemplateAllowlist()` method that reads the allowlist from DB
- Passes `AllowedTemplateIDs` callback to `ListTemplates`, `ReadTemplate`, and `CreateWorkspace` tool constructors
> 🤖 Created by Coder Agents and reviewed by a human.
This commit is contained in:
+6
-19
@@ -44,6 +44,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/searchquery"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/util/xjson"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
@@ -2870,7 +2871,7 @@ func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request
|
||||
})
|
||||
return
|
||||
}
|
||||
ids, parseErr := parseChatTemplateAllowlist(raw)
|
||||
parsed, parseErr := xjson.ParseUUIDList(raw)
|
||||
if parseErr != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Stored template allowlist is corrupt.",
|
||||
@@ -2878,6 +2879,10 @@ func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request
|
||||
})
|
||||
return
|
||||
}
|
||||
ids := make([]string, len(parsed))
|
||||
for i, id := range parsed {
|
||||
ids[i] = id.String()
|
||||
}
|
||||
resp := codersdk.ChatTemplateAllowlist{
|
||||
TemplateIDs: ids,
|
||||
}
|
||||
@@ -2983,24 +2988,6 @@ func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// parseChatTemplateAllowlist parses the raw JSON string from the
|
||||
// database into a list of template ID strings. Returns an empty
|
||||
// slice when the value is empty. Returns an error when the stored
|
||||
// JSON is corrupt or otherwise cannot be unmarshalled.
|
||||
func parseChatTemplateAllowlist(raw string) ([]string, error) {
|
||||
if raw == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
var ids []string
|
||||
if err := json.Unmarshal([]byte(raw), &ids); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal template allowlist: %w", err)
|
||||
}
|
||||
if ids == nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package xjson
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ParseUUIDList parses a JSON-encoded array of UUID strings
|
||||
// (e.g. `["uuid1","uuid2"]`) and returns the corresponding
|
||||
// slice of uuid.UUID values. An empty input (including
|
||||
// whitespace-only) returns an empty (non-nil) slice.
|
||||
func ParseUUIDList(raw string) ([]uuid.UUID, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return []uuid.UUID{}, nil
|
||||
}
|
||||
|
||||
var strs []string
|
||||
if err := json.Unmarshal([]byte(raw), &strs); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal uuid list: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]uuid.UUID, 0, len(strs))
|
||||
for _, s := range strs {
|
||||
id, err := uuid.Parse(s)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse uuid %q: %w", s, err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package xjson_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/xjson"
|
||||
)
|
||||
|
||||
func TestParseUUIDList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5")
|
||||
b := uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []uuid.UUID
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "EmptyString",
|
||||
input: "",
|
||||
want: []uuid.UUID{},
|
||||
},
|
||||
{
|
||||
name: "JSONNull",
|
||||
input: "null",
|
||||
want: []uuid.UUID{},
|
||||
},
|
||||
{
|
||||
name: "WhitespaceOnly",
|
||||
input: " \n\t ",
|
||||
want: []uuid.UUID{},
|
||||
},
|
||||
{
|
||||
name: "ValidUUIDs",
|
||||
input: `["c7c6686d-a93c-4df2-bef9-5f837e9a33d5","8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818"]`,
|
||||
want: []uuid.UUID{a, b},
|
||||
},
|
||||
{
|
||||
name: "InvalidJSON",
|
||||
input: "not json at all",
|
||||
wantErr: "unmarshal uuid list",
|
||||
},
|
||||
{
|
||||
name: "InvalidUUID",
|
||||
input: `["not-a-uuid"]`,
|
||||
wantErr: "parse uuid",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := xjson.ParseUUIDList(tt.input)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
+38
-4
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/util/xjson"
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
|
||||
@@ -121,6 +122,36 @@ type Server struct {
|
||||
chatHeartbeatInterval time.Duration
|
||||
}
|
||||
|
||||
// chatTemplateAllowlist returns the deployment-wide template
|
||||
// allowlist as a set of permitted template IDs. The callback
|
||||
// signature matches what the chat tools expect. When the
|
||||
// allowlist is empty or cannot be loaded the function returns
|
||||
// nil, which the tools interpret as "all templates allowed".
|
||||
func (p *Server) chatTemplateAllowlist() map[uuid.UUID]bool {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
||||
// access for reading deployment config.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped read
|
||||
// access to deployment config (the template allowlist).
|
||||
ctx = dbauthz.AsChatd(ctx)
|
||||
raw, err := p.db.GetChatTemplateAllowlist(ctx)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to load chat template allowlist", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
ids, err := xjson.ParseUUIDList(raw)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to parse chat template allowlist", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
m := make(map[uuid.UUID]bool, len(ids))
|
||||
for _, id := range ids {
|
||||
m[id] = true
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type turnWorkspaceContext struct {
|
||||
server *Server
|
||||
chatStateMu *sync.Mutex
|
||||
@@ -3413,12 +3444,14 @@ func (p *Server) runChat(
|
||||
// Workspace provisioning tools.
|
||||
tools = append(tools,
|
||||
chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: p.db,
|
||||
@@ -3429,6 +3462,7 @@ func (p *Server) runChat(
|
||||
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
|
||||
WorkspaceMu: &workspaceMu,
|
||||
Logger: p.logger,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
chattool.StartWorkspace(chattool.StartWorkspaceOptions{
|
||||
DB: p.db,
|
||||
|
||||
@@ -3685,3 +3685,116 @@ func TestMCPServerToolInvocation(t *testing.T) {
|
||||
require.True(t, foundToolMessage,
|
||||
"MCP tool result should be persisted as a tool message in the database")
|
||||
}
|
||||
|
||||
func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
|
||||
// Set up a mock OpenAI server. The first streaming call triggers
|
||||
// list_templates; subsequent calls respond with text.
|
||||
var callCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
if callCount.Add(1) == 1 {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("list_templates", `{}`),
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Here are the templates.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
|
||||
// Create two templates the user can see.
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
tplAllowed := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "allowed-template",
|
||||
})
|
||||
tplBlocked := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "blocked-template",
|
||||
})
|
||||
|
||||
// Set the allowlist to only tplAllowed.
|
||||
allowlistJSON, err := json.Marshal([]string{tplAllowed.ID.String()})
|
||||
require.NoError(t, err)
|
||||
err = db.UpsertChatTemplateAllowlist(dbauthz.AsSystemRestricted(ctx), string(allowlistJSON))
|
||||
require.NoError(t, err)
|
||||
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "allowlist-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("List templates"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the chat to finish processing.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
if chatResult.Status == database.ChatStatusError {
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
// Find the list_templates tool result in the persisted messages.
|
||||
var toolResult string
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if msg.Role != database.ChatMessageRoleTool {
|
||||
continue
|
||||
}
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolResult &&
|
||||
part.ToolName == "list_templates" {
|
||||
toolResult = string(part.Result)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
require.NotEmpty(t, toolResult, "list_templates tool result should be persisted")
|
||||
|
||||
// The result should contain only the allowed template.
|
||||
require.Contains(t, toolResult, tplAllowed.ID.String(),
|
||||
"allowed template should appear in list_templates result")
|
||||
require.NotContains(t, toolResult, tplBlocked.ID.String(),
|
||||
"blocked template should NOT appear in list_templates result")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user