diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 9baa3ef511..79f97f3ffe 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -44,6 +44,7 @@ import ( "github.com/coder/coder/v2/coderd/searchquery" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/util/xjson" "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" @@ -2870,7 +2871,7 @@ func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request }) return } - ids, parseErr := parseChatTemplateAllowlist(raw) + parsed, parseErr := xjson.ParseUUIDList(raw) if parseErr != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Stored template allowlist is corrupt.", @@ -2878,6 +2879,10 @@ func (api *API) getChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request }) return } + ids := make([]string, len(parsed)) + for i, id := range parsed { + ids[i] = id.String() + } resp := codersdk.ChatTemplateAllowlist{ TemplateIDs: ids, } @@ -2983,24 +2988,6 @@ func (api *API) putChatTemplateAllowlist(rw http.ResponseWriter, r *http.Request rw.WriteHeader(http.StatusNoContent) } -// parseChatTemplateAllowlist parses the raw JSON string from the -// database into a list of template ID strings. Returns an empty -// slice when the value is empty. Returns an error when the stored -// JSON is corrupt or otherwise cannot be unmarshalled. -func parseChatTemplateAllowlist(raw string) ([]string, error) { - if raw == "" { - return []string{}, nil - } - var ids []string - if err := json.Unmarshal([]byte(raw), &ids); err != nil { - return nil, xerrors.Errorf("unmarshal template allowlist: %w", err) - } - if ids == nil { - return []string{}, nil - } - return ids, nil -} - // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. diff --git a/coderd/util/xjson/xjson.go b/coderd/util/xjson/xjson.go new file mode 100644 index 0000000000..9d900e2305 --- /dev/null +++ b/coderd/util/xjson/xjson.go @@ -0,0 +1,35 @@ +package xjson + +import ( + "encoding/json" + "strings" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +// ParseUUIDList parses a JSON-encoded array of UUID strings +// (e.g. `["uuid1","uuid2"]`) and returns the corresponding +// slice of uuid.UUID values. An empty input (including +// whitespace-only) returns an empty (non-nil) slice. +func ParseUUIDList(raw string) ([]uuid.UUID, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return []uuid.UUID{}, nil + } + + var strs []string + if err := json.Unmarshal([]byte(raw), &strs); err != nil { + return nil, xerrors.Errorf("unmarshal uuid list: %w", err) + } + + ids := make([]uuid.UUID, 0, len(strs)) + for _, s := range strs { + id, err := uuid.Parse(s) + if err != nil { + return nil, xerrors.Errorf("parse uuid %q: %w", s, err) + } + ids = append(ids, id) + } + return ids, nil +} diff --git a/coderd/util/xjson/xjson_test.go b/coderd/util/xjson/xjson_test.go new file mode 100644 index 0000000000..3a94811729 --- /dev/null +++ b/coderd/util/xjson/xjson_test.go @@ -0,0 +1,70 @@ +package xjson_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/util/xjson" +) + +func TestParseUUIDList(t *testing.T) { + t.Parallel() + + a := uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5") + b := uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818") + + tests := []struct { + name string + input string + want []uuid.UUID + wantErr string + }{ + { + name: "EmptyString", + input: "", + want: []uuid.UUID{}, + }, + { + name: "JSONNull", + input: "null", + want: []uuid.UUID{}, + }, + { + name: "WhitespaceOnly", + input: " \n\t ", + want: []uuid.UUID{}, + }, + { + name: "ValidUUIDs", + input: `["c7c6686d-a93c-4df2-bef9-5f837e9a33d5","8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818"]`, + want: []uuid.UUID{a, b}, + }, + { + name: "InvalidJSON", + input: "not json at all", + wantErr: "unmarshal uuid list", + }, + { + name: "InvalidUUID", + input: `["not-a-uuid"]`, + wantErr: "parse uuid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := xjson.ParseUUIDList(tt.input) + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 7046e398a6..710d90db2e 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -28,6 +28,7 @@ import ( "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/util/xjson" "github.com/coder/coder/v2/coderd/webpush" "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/coderd/x/chatd/chatcost" @@ -121,6 +122,36 @@ type Server struct { chatHeartbeatInterval time.Duration } +// chatTemplateAllowlist returns the deployment-wide template +// allowlist as a set of permitted template IDs. The callback +// signature matches what the chat tools expect. When the +// allowlist is empty or cannot be loaded the function returns +// nil, which the tools interpret as "all templates allowed". +func (p *Server) chatTemplateAllowlist() map[uuid.UUID]bool { + //nolint:gocritic // AsChatd provides narrowly-scoped daemon + // access for reading deployment config. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + //nolint:gocritic // AsChatd provides narrowly-scoped read + // access to deployment config (the template allowlist). + ctx = dbauthz.AsChatd(ctx) + raw, err := p.db.GetChatTemplateAllowlist(ctx) + if err != nil { + p.logger.Warn(ctx, "failed to load chat template allowlist", slog.Error(err)) + return nil + } + ids, err := xjson.ParseUUIDList(raw) + if err != nil { + p.logger.Warn(ctx, "failed to parse chat template allowlist", slog.Error(err)) + return nil + } + m := make(map[uuid.UUID]bool, len(ids)) + for _, id := range ids { + m[id] = true + } + return m +} + type turnWorkspaceContext struct { server *Server chatStateMu *sync.Mutex @@ -3413,12 +3444,14 @@ func (p *Server) runChat( // Workspace provisioning tools. tools = append(tools, chattool.ListTemplates(chattool.ListTemplatesOptions{ - DB: p.db, - OwnerID: chat.OwnerID, + DB: p.db, + OwnerID: chat.OwnerID, + AllowedTemplateIDs: p.chatTemplateAllowlist, }), chattool.ReadTemplate(chattool.ReadTemplateOptions{ - DB: p.db, - OwnerID: chat.OwnerID, + DB: p.db, + OwnerID: chat.OwnerID, + AllowedTemplateIDs: p.chatTemplateAllowlist, }), chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{ DB: p.db, @@ -3429,6 +3462,7 @@ func (p *Server) runChat( AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout, WorkspaceMu: &workspaceMu, Logger: p.logger, + AllowedTemplateIDs: p.chatTemplateAllowlist, }), chattool.StartWorkspace(chattool.StartWorkspaceOptions{ DB: p.db, diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 0aa9c3d408..66172177e9 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -3685,3 +3685,116 @@ func TestMCPServerToolInvocation(t *testing.T) { require.True(t, foundToolMessage, "MCP tool result should be persisted as a tool message in the database") } + +func TestChatTemplateAllowlistEnforcement(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + + // Set up a mock OpenAI server. The first streaming call triggers + // list_templates; subsequent calls respond with text. + var callCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if callCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("list_templates", `{}`), + ) + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Here are the templates.")..., + ) + }) + + user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + + // Create two templates the user can see. + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + tplAllowed := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + Name: "allowed-template", + }) + tplBlocked := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + Name: "blocked-template", + }) + + // Set the allowlist to only tplAllowed. + allowlistJSON, err := json.Marshal([]string{tplAllowed.ID.String()}) + require.NoError(t, err) + err = db.UpsertChatTemplateAllowlist(dbauthz.AsSystemRestricted(ctx), string(allowlistJSON)) + require.NoError(t, err) + + server := newActiveTestServer(t, db, ps) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "allowlist-test", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("List templates"), + }, + }) + require.NoError(t, err) + + // Wait for the chat to finish processing. + var chatResult database.Chat + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String) + } + + // Find the list_templates tool result in the persisted messages. + var toolResult string + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleTool { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolResult && + part.ToolName == "list_templates" { + toolResult = string(part.Result) + return true + } + } + } + return false + }, testutil.IntervalFast) + + require.NotEmpty(t, toolResult, "list_templates tool result should be persisted") + + // The result should contain only the allowed template. + require.Contains(t, toolResult, tplAllowed.ID.String(), + "allowed template should appear in list_templates result") + require.NotContains(t, toolResult, tplBlocked.ID.String(), + "blocked template should NOT appear in list_templates result") +}