fix: address post-merge review findings for chat org scoping (#24297)

Addresses review findings from #23827 that were added post-merge:

- Persisted attachments now store `organizationId`; mismatched orgs
pruned on restore
- Workspace selection reconciliation: stale IDs from previous orgs
dropped via derived `effectiveWorkspaceId`
- Org picker uses `permittedOrganizations()` for RBAC-aware filtering
- Org picker hidden when user belongs to only one org
- Ref-sync `useEffect` replaced with `useEffectEvent`
- `CreateWorkspace()` and `ListTemplates()` take `organizationID` and
`db` as required function parameters instead of optional struct fields —
compiler enforces them, removes scattered nil guards
- Cross-org template check in `CreateWorkspace` is now unconditional
- `ListTemplates` org-scoping filter now has test coverage
- `setupChatInfra` comment fixed; test helpers use params structs
instead of positional UUIDs
- Enterprise test documents that org admin only sees own chats (handler
hardcodes `OwnerID` — future work needs sidebar UI before lifting that
restriction)

> 🤖
This commit is contained in:
Cian Johnston
2026-04-15 11:39:05 +01:00
committed by GitHub
parent 5812f84e1c
commit 6194bd6f57
15 changed files with 767 additions and 288 deletions
+55 -35
View File
@@ -10316,7 +10316,8 @@ func TestGetPRInsights(t *testing.T) {
}
// setupChatInfra creates a fresh database with a user, chat provider,
// and model config. Returns the store, user ID, and model config ID.
// and model config. Returns the store, user ID, model config ID,
// and org ID.
setupChatInfra := func(t *testing.T) (database.Store, uuid.UUID, uuid.UUID, uuid.UUID) {
t.Helper()
store, _ := dbtestutil.NewDB(t)
@@ -10351,13 +10352,20 @@ func TestGetPRInsights(t *testing.T) {
return store, user.ID, mc.ID, org.ID
}
createChat := func(t *testing.T, store database.Store, userID, mcID, orgID uuid.UUID, title string) database.Chat {
type chatParams struct {
Store database.Store
UserID uuid.UUID
ModelConfigID uuid.UUID
OrgID uuid.UUID
}
createChat := func(t *testing.T, p chatParams, title string) database.Chat {
t.Helper()
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
OrganizationID: orgID,
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
OrganizationID: p.OrgID,
Status: database.ChatStatusWaiting,
OwnerID: userID,
LastModelConfigID: mcID,
OwnerID: p.UserID,
LastModelConfigID: p.ModelConfigID,
Title: title,
})
require.NoError(t, err)
@@ -10416,11 +10424,12 @@ func TestGetPRInsights(t *testing.T) {
t.Run("MultipleChatsSamePR_CostSummed", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
chatA := createChat(t, store, userID, mcID, orgID, "chat-A")
chatA := createChat(t, p, "chat-A")
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000) // $5
chatB := createChat(t, store, userID, mcID, orgID, "chat-B")
chatB := createChat(t, p, "chat-B")
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000) // $3
prURL := "https://github.com/org/repo/pull/123"
@@ -10452,12 +10461,13 @@ func TestGetPRInsights(t *testing.T) {
t.Run("DifferentPRs_NoDuplication", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
chatA := createChat(t, store, userID, mcID, orgID, "chat-A")
chatA := createChat(t, p, "chat-A")
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000)
linkPR(t, store, chatA.ID, "https://github.com/org/repo/pull/1", "merged", "feat: A", 50, 10, 2)
chatB := createChat(t, store, userID, mcID, orgID, "chat-B")
chatB := createChat(t, p, "chat-B")
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000)
linkPR(t, store, chatB.ID, "https://github.com/org/repo/pull/2", "open", "feat: B", 80, 30, 4)
@@ -10486,13 +10496,13 @@ func TestGetPRInsights(t *testing.T) {
// createChildChat creates a chat with ParentChatID and RootChatID
// set, simulating a subagent/child chat in a tree.
createChildChat := func(t *testing.T, store database.Store, userID, mcID, orgID, parentID, rootID uuid.UUID, title string) database.Chat {
createChildChat := func(t *testing.T, p chatParams, parentID, rootID uuid.UUID, title string) database.Chat {
t.Helper()
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
OrganizationID: orgID,
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
OrganizationID: p.OrgID,
Status: database.ChatStatusWaiting,
OwnerID: userID,
LastModelConfigID: mcID,
OwnerID: p.UserID,
LastModelConfigID: p.ModelConfigID,
Title: title,
ParentChatID: uuid.NullUUID{UUID: parentID, Valid: true},
RootChatID: uuid.NullUUID{UUID: rootID, Valid: true},
@@ -10504,10 +10514,11 @@ func TestGetPRInsights(t *testing.T) {
t.Run("DuplicatePRUrl_CountedOnce", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
prURL := "https://github.com/org/repo/pull/99"
for i := 0; i < 3; i++ {
chat := createChat(t, store, userID, mcID, orgID, fmt.Sprintf("chat-%d", i))
for i := range 3 {
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
linkPR(t, store, chat.ID, prURL, "merged", "fix: same PR", 40, 10, 3)
}
@@ -10533,18 +10544,19 @@ func TestGetPRInsights(t *testing.T) {
t.Run("ChildChatCostsIncluded", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Parent chat with a $5 cost.
parent := createChat(t, store, userID, mcID, orgID, "parent-chat")
parent := createChat(t, p, "parent-chat")
insertCostMessage(t, store, parent.ID, userID, mcID, 5_000_000)
// Two child chats (subagents) with $2 each. Only the parent
// has a chat_diff_statuses entry, but the children's costs
// should be included via the tree join.
child1 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-1")
child1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
insertCostMessage(t, store, child1.ID, userID, mcID, 2_000_000)
child2 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-2")
child2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
insertCostMessage(t, store, child2.ID, userID, mcID, 2_000_000)
prURL := "https://github.com/org/repo/pull/42"
@@ -10575,18 +10587,19 @@ func TestGetPRInsights(t *testing.T) {
t.Run("SiblingPRs_NoCrossContamination", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Parent chat with $10 orchestration cost.
parent := createChat(t, store, userID, mcID, orgID, "parent")
parent := createChat(t, p, "parent")
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
// Child C1 ($5) creates PR1.
c1 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-1")
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/10", "merged", "feat: PR1", 50, 10, 2)
// Child C2 ($3) creates PR2.
c2 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-2")
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
insertCostMessage(t, store, c2.ID, userID, mcID, 3_000_000)
linkPR(t, store, c2.ID, "https://github.com/org/repo/pull/11", "open", "feat: PR2", 30, 5, 1)
@@ -10618,22 +10631,23 @@ func TestGetPRInsights(t *testing.T) {
t.Run("ParentAndChildDifferentPRs_NoCrossContamination", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Parent P ($10) creates PR1.
parent := createChat(t, store, userID, mcID, orgID, "parent")
parent := createChat(t, p, "parent")
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
linkPR(t, store, parent.ID, "https://github.com/org/repo/pull/20", "merged", "feat: parent PR", 80, 20, 4)
// Child C1 ($5) has its own PR2. Because C1 has its own
// chat_diff_statuses entry, its cost should NOT be included
// under PR1 — it belongs to PR2 only.
c1 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-1")
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/21", "open", "feat: child PR", 30, 5, 1)
// Child C2 ($2) has NO cds entry — pure subagent.
// Its cost should be included under PR1 (the parent's PR).
c2 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-2")
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
insertCostMessage(t, store, c2.ID, userID, mcID, 2_000_000)
// PR1 cost = parent ($10) + C2 ($2) = $12 (C1 excluded)
@@ -10663,15 +10677,16 @@ func TestGetPRInsights(t *testing.T) {
t.Run("EmptyURLNotCollapsed", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Two chats with empty-string URLs should be treated as
// separate PRs (NULLIF converts '' to NULL, falling back
// to c.id::text).
chatX := createChat(t, store, userID, mcID, orgID, "chat-X")
chatX := createChat(t, p, "chat-X")
insertCostMessage(t, store, chatX.ID, userID, mcID, 4_000_000)
linkPR(t, store, chatX.ID, "", "open", "draft: X", 10, 2, 1)
chatY := createChat(t, store, userID, mcID, orgID, "chat-Y")
chatY := createChat(t, p, "chat-Y")
insertCostMessage(t, store, chatY.ID, userID, mcID, 6_000_000)
linkPR(t, store, chatY.ID, "", "merged", "draft: Y", 20, 5, 2)
@@ -10696,13 +10711,14 @@ func TestGetPRInsights(t *testing.T) {
t.Run("ParentAndChildSameURL_DedupedWithCombinedCost", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Parent P ($10) links to a PR.
parent := createChat(t, store, userID, mcID, orgID, "parent")
parent := createChat(t, p, "parent")
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
// Child C ($5) also links to the same PR URL.
child := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child")
child := createChildChat(t, p, parent.ID, parent.ID, "child")
insertCostMessage(t, store, child.ID, userID, mcID, 5_000_000)
prURL := "https://github.com/org/repo/pull/50"
@@ -10733,10 +10749,11 @@ func TestGetPRInsights(t *testing.T) {
t.Run("ZeroCostChat_StillCounted", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// A chat linked to a PR but with NO chat_messages at all.
// The PR should still appear with zero cost.
chat := createChat(t, store, userID, mcID, orgID, "zero-cost-chat")
chat := createChat(t, p, "zero-cost-chat")
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/60", "open", "feat: no messages", 25, 5, 2)
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
@@ -10777,7 +10794,8 @@ func TestGetPRInsights(t *testing.T) {
})
require.NoError(t, err)
chat := createChat(t, store, userID, emptyDisplayModel.ID, orgID, "chat-empty-display-name")
p := chatParams{Store: store, UserID: userID, ModelConfigID: emptyDisplayModel.ID, OrgID: orgID}
chat := createChat(t, p, "chat-empty-display-name")
insertCostMessage(t, store, chat.ID, userID, emptyDisplayModel.ID, 1_000_000)
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/72", "merged", "fix: blank display name", 10, 2, 1)
@@ -10803,14 +10821,15 @@ func TestGetPRInsights(t *testing.T) {
t.Run("MergedCostMicros_OnlyCountsMerged", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Merged PR with $5 cost.
chatMerged := createChat(t, store, userID, mcID, orgID, "chat-merged")
chatMerged := createChat(t, p, "chat-merged")
insertCostMessage(t, store, chatMerged.ID, userID, mcID, 5_000_000)
linkPR(t, store, chatMerged.ID, "https://github.com/org/repo/pull/70", "merged", "fix: merged", 40, 10, 2)
// Open PR with $3 cost.
chatOpen := createChat(t, store, userID, mcID, orgID, "chat-open")
chatOpen := createChat(t, p, "chat-open")
insertCostMessage(t, store, chatOpen.ID, userID, mcID, 3_000_000)
linkPR(t, store, chatOpen.ID, "https://github.com/org/repo/pull/71", "open", "feat: open", 20, 5, 1)
@@ -10829,12 +10848,13 @@ func TestGetPRInsights(t *testing.T) {
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
t.Parallel()
store, userID, mcID, orgID := setupChatInfra(t)
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
// Create 25 distinct PRs — more than the old LIMIT 20 — and
// verify all are returned.
const prCount = 25
for i := range prCount {
chat := createChat(t, store, userID, mcID, orgID, fmt.Sprintf("chat-%d", i))
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
linkPR(t, store, chat.ID,
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
+7 -11
View File
@@ -5059,21 +5059,16 @@ func (p *Server) runChat(
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
}
tools = append(tools,
chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: p.db,
OwnerID: chat.OwnerID,
OrganizationID: chat.OrganizationID,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: p.db,
chattool.ListTemplates(chat.OrganizationID, p.db, chattool.ListTemplatesOptions{
OwnerID: chat.OwnerID,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: p.db,
chattool.ReadTemplate(chat.OrganizationID, p.db, chattool.ReadTemplateOptions{
OwnerID: chat.OwnerID,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.CreateWorkspace(chat.OrganizationID, p.db, chattool.CreateWorkspaceOptions{
OwnerID: chat.OwnerID,
OrganizationID: chat.OrganizationID,
ChatID: chat.ID,
CreateFn: p.createWorkspaceFn,
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
@@ -5083,6 +5078,7 @@ func (p *Server) runChat(
Logger: p.logger,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: p.db,
OwnerID: chat.OwnerID,
+52 -71
View File
@@ -61,9 +61,7 @@ type AgentConnFunc func(
// CreateWorkspaceOptions configures the create_workspace tool.
type CreateWorkspaceOptions struct {
DB database.Store
OwnerID uuid.UUID
OrganizationID uuid.UUID
ChatID uuid.UUID
CreateFn CreateWorkspaceFn
AgentConnFn AgentConnFunc
@@ -85,7 +83,7 @@ type createWorkspaceArgs struct {
// workspace that is building or running, it returns the existing
// workspace instead of creating a new one. A mutex prevents parallel
// calls from creating duplicate workspaces.
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
func CreateWorkspace(organizationID uuid.UUID, db database.Store, options CreateWorkspaceOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"create_workspace",
"Create a new workspace from a template. Requires a "+
@@ -96,6 +94,9 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
"workspace that is building or running, the existing "+
"workspace is returned.",
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if db == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
if options.CreateFn == nil {
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
}
@@ -123,7 +124,7 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
}
// Check for an existing workspace on the chat.
check := options.checkExistingWorkspace(ctx)
check := options.checkExistingWorkspace(ctx, db)
if check.Err != nil {
if check.FailedBuildID != uuid.Nil {
return buildToolResponse(newBuildError(check.Err.Error(), check.FailedBuildID)), nil
@@ -136,50 +137,44 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
ownerID := options.OwnerID
// Set up dbauthz context for DB lookups.
if options.DB != nil {
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
if ownerErr != nil {
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
}
ctx = ownerCtx
ownerCtx, ownerErr := asOwner(ctx, db, ownerID)
if ownerErr != nil {
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
}
ctx = ownerCtx
// Verify the template belongs to the same org as the
// chat. Without this check the tool could silently
// bind a cross-org workspace to the chat.
if options.DB != nil && options.OrganizationID != uuid.Nil {
tmpl, tmplErr := options.DB.GetTemplateByID(ctx, templateID)
if tmplErr != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("look up template: %w", tmplErr).Error(),
), nil
}
if tmpl.OrganizationID != options.OrganizationID {
return fantasy.NewTextErrorResponse(
"template belongs to a different organization than this chat; " +
"use list_templates to find templates in the correct organization",
), nil
}
tmpl, tmplErr := db.GetTemplateByID(ctx, templateID)
if tmplErr != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("look up template: %w", tmplErr).Error(),
), nil
}
if tmpl.OrganizationID != organizationID {
return fantasy.NewTextErrorResponse(
"template belongs to a different organization than this chat; " +
"use list_templates to find templates in the correct organization",
), nil
}
var ttlMs *int64
if options.DB != nil {
raw, err := options.DB.GetChatWorkspaceTTL(ctx)
if err != nil {
options.Logger.Error(ctx, "failed to read chat workspace TTL setting, using template default",
slog.Error(err),
raw, err := db.GetChatWorkspaceTTL(ctx)
if err != nil {
options.Logger.Error(ctx, "failed to read chat workspace TTL setting, using template default",
slog.Error(err),
)
} else {
d, parseErr := codersdk.ParseChatWorkspaceTTL(raw)
if parseErr != nil {
options.Logger.Warn(ctx, "invalid chat workspace TTL setting, using template default",
slog.F("raw", raw),
slog.Error(parseErr),
)
} else {
d, parseErr := codersdk.ParseChatWorkspaceTTL(raw)
if parseErr != nil {
options.Logger.Warn(ctx, "invalid chat workspace TTL setting, using template default",
slog.F("raw", raw),
slog.Error(parseErr),
)
} else if d > 0 {
ms := d.Milliseconds()
ttlMs = &ms
}
} else if d > 0 {
ms := d.Milliseconds()
ttlMs = &ms
}
}
@@ -188,21 +183,9 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
TTLMillis: ttlMs,
}
// Resolve workspace name. This does a second
// GetTemplateByID when no name is provided; the first
// is the org-validation check above. Consolidating
// them would couple the security gate to the
// name-fallback path, and the cost is negligible next
// to the workspace build that follows.
name := strings.TrimSpace(args.Name)
if name == "" {
seed := "workspace"
if options.DB != nil {
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
seed = t.Name
}
}
name = generatedWorkspaceName(seed)
name = generatedWorkspaceName(tmpl.Name)
} else if err := codersdk.NameValid(name); err != nil {
name = generatedWorkspaceName(name)
}
@@ -228,8 +211,8 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
// later fails. The checkExistingWorkspace recovery
// path handles failed workspaces by allowing
// re-creation.
if options.DB != nil && options.ChatID != uuid.Nil {
updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
if options.ChatID != uuid.Nil {
updatedChat, err := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
@@ -259,8 +242,8 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
// come online so subsequent tools can use the
// workspace immediately.
buildID := workspace.LatestBuild.ID
if options.DB != nil && buildID != uuid.Nil {
if err := waitForBuild(ctx, options.DB, buildID); err != nil {
if buildID != uuid.Nil {
if err := waitForBuild(ctx, db, buildID); err != nil {
return buildToolResponse(newBuildError(
xerrors.Errorf("workspace build failed: %w", err).Error(),
buildID,
@@ -276,26 +259,24 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
// Select the chat agent so follow-up tools wait on the
// intended workspace agent.
workspaceAgentID := uuid.Nil
if options.DB != nil {
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil {
if len(agents) == 0 {
result["agent_status"] = "no_agent"
agents, agentErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil {
if len(agents) == 0 {
result["agent_status"] = "no_agent"
} else {
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
result["agent_status"] = "selection_error"
result["agent_error"] = selectErr.Error()
} else {
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
result["agent_status"] = "selection_error"
result["agent_error"] = selectErr.Error()
} else {
workspaceAgentID = selected.ID
}
workspaceAgentID = selected.ID
}
}
}
// Wait for the agent to come online and startup scripts to finish.
if workspaceAgentID != uuid.Nil {
agentStatus := waitForAgentReady(ctx, options.DB, workspaceAgentID, options.AgentConnFn)
agentStatus := waitForAgentReady(ctx, db, workspaceAgentID, options.AgentConnFn)
for k, v := range agentStatus {
result[k] = v
}
@@ -327,12 +308,12 @@ type existingWorkspaceResult struct {
// (workspace is dead or missing).
func (o CreateWorkspaceOptions) checkExistingWorkspace(
ctx context.Context,
db database.Store,
) existingWorkspaceResult {
if o.DB == nil || o.ChatID == uuid.Nil {
if o.ChatID == uuid.Nil {
return existingWorkspaceResult{}
}
db := o.DB
chatID := o.ChatID
agentConnFn := o.AgentConnFn
agentInactiveDisconnectTimeout := o.AgentInactiveDisconnectTimeout
+80 -36
View File
@@ -126,6 +126,7 @@ func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
orgID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
@@ -142,6 +143,13 @@ func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetTemplateByID(gomock.Any(), templateID).
Return(database.Template{
ID: templateID,
OrganizationID: orgID,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return("0s", nil)
@@ -187,9 +195,9 @@ func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
return nil, func() {}, nil
}
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
OwnerID: ownerID,
CreateFn: createFn,
AgentConnFn: agentConnFn,
WorkspaceMu: &sync.Mutex{},
@@ -218,6 +226,7 @@ func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
orgID := uuid.New()
chatID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
@@ -235,9 +244,16 @@ func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
Groups: []string{},
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetTemplateByID(gomock.Any(), templateID).
Return(database.Template{
ID: templateID,
OrganizationID: orgID,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return("0s", nil)
db.EXPECT().
GetWorkspaceBuildByID(gomock.Any(), buildID).
Return(database.WorkspaceBuild{
@@ -269,10 +285,10 @@ func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
{ID: uuid.New(), Name: "beta-coderd-chat", DisplayOrder: 1},
}, nil)
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
OwnerID: ownerID,
ChatID: chatID,
ChatID: chatID,
CreateFn: func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
return codersdk.Workspace{
ID: workspaceID,
@@ -315,6 +331,7 @@ func TestCreateWorkspace_PostCreationBuildFailure(t *testing.T) {
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
orgID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
@@ -329,6 +346,13 @@ func TestCreateWorkspace_PostCreationBuildFailure(t *testing.T) {
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetTemplateByID(gomock.Any(), templateID).
Return(database.Template{
ID: templateID,
OrganizationID: orgID,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return("0s", nil)
@@ -362,9 +386,9 @@ func TestCreateWorkspace_PostCreationBuildFailure(t *testing.T) {
}, nil
}
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
OwnerID: ownerID,
ChatID: uuid.Nil,
CreateFn: createFn,
WorkspaceMu: &sync.Mutex{},
@@ -426,6 +450,7 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
orgID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
@@ -440,6 +465,13 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetTemplateByID(gomock.Any(), templateID).
Return(database.Template{
ID: templateID,
OrganizationID: orgID,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return(tc.ttlReturn, tc.ttlErr)
@@ -475,9 +507,9 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
}, nil
}
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
OwnerID: ownerID,
ChatID: uuid.Nil,
CreateFn: createFn,
WorkspaceMu: &sync.Mutex{},
@@ -546,11 +578,10 @@ func TestCreateWorkspace_RejectsCrossOrgTemplate(t *testing.T) {
}, nil)
createCalled := false
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
OrganizationID: chatOrgID,
ChatID: chatID,
tool := CreateWorkspace(chatOrgID, db, CreateWorkspaceOptions{
OwnerID: ownerID,
ChatID: chatID,
CreateFn: func(context.Context, uuid.UUID, codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
createCalled = true
return codersdk.Workspace{}, nil
@@ -610,8 +641,9 @@ func TestCheckExistingWorkspace_ConnectedAgent(t *testing.T) {
return nil, nil, xerrors.New("unexpected agent dial")
}
options := testCheckExistingWorkspaceOptions(db, chatID, connFn)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, connFn)
check := options.checkExistingWorkspace(context.Background(), db)
require.NoError(t, check.Err)
require.True(t, check.Done)
require.Equal(t, "already_exists", check.Result["status"])
@@ -702,8 +734,9 @@ func TestCheckExistingWorkspace_InProgressBuildReturnsBuildID(t *testing.T) {
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{}, nil)
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, nil)
check := options.checkExistingWorkspace(context.Background(), db)
require.NoError(t, check.Err)
require.True(t, check.Done)
require.Equal(t, false, check.Result["created"])
@@ -784,8 +817,9 @@ func TestCheckExistingWorkspace_InProgressBuildFailureReturnsBuildID(t *testing.
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
}, nil)
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, nil)
check := options.checkExistingWorkspace(context.Background(), db)
require.Error(t, check.Err)
require.Contains(t, check.Err.Error(), "existing workspace build failed")
require.Equal(t, buildID, check.FailedBuildID)
@@ -831,8 +865,9 @@ func TestCheckExistingWorkspace_ConnectingAgentWaits(t *testing.T) {
return nil, func() {}, nil
}
options := testCheckExistingWorkspaceOptions(db, chatID, connFn)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, connFn)
check := options.checkExistingWorkspace(context.Background(), db)
require.NoError(t, check.Err)
require.True(t, check.Done)
require.Equal(t, 1, connectCalls)
@@ -891,8 +926,9 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) {
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{tc.agent}, nil)
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, nil)
check := options.checkExistingWorkspace(context.Background(), db)
require.NoError(t, check.Err)
require.False(t, check.Done)
require.Nil(t, check.Result)
@@ -907,6 +943,7 @@ func TestWaitForBuild_CanceledJob(t *testing.T) {
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
orgID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
@@ -921,6 +958,13 @@ func TestWaitForBuild_CanceledJob(t *testing.T) {
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetTemplateByID(gomock.Any(), templateID).
Return(database.Template{
ID: templateID,
OrganizationID: orgID,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return("0s", nil)
@@ -953,9 +997,9 @@ func TestWaitForBuild_CanceledJob(t *testing.T) {
}, nil
}
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
OwnerID: ownerID,
ChatID: uuid.Nil,
CreateFn: createFn,
WorkspaceMu: &sync.Mutex{},
@@ -997,8 +1041,9 @@ func TestCheckExistingWorkspace_StoppedWorkspace(t *testing.T) {
database.WorkspaceTransitionStop,
)
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, nil)
check := options.checkExistingWorkspace(context.Background(), db)
require.True(t, check.Done)
require.NoError(t, check.Err)
require.Equal(t, "stopped", check.Result["status"])
@@ -1029,20 +1074,19 @@ func TestCheckExistingWorkspace_DeletedWorkspace(t *testing.T) {
Deleted: true,
}, nil)
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
check := options.checkExistingWorkspace(context.Background())
options := testCheckExistingWorkspaceOptions(chatID, nil)
check := options.checkExistingWorkspace(context.Background(), db)
require.NoError(t, check.Err)
require.False(t, check.Done, "should allow creation for deleted workspace")
require.Nil(t, check.Result)
}
func testCheckExistingWorkspaceOptions(
db *dbmock.MockStore,
chatID uuid.UUID,
agentConnFn AgentConnFunc,
) CreateWorkspaceOptions {
return CreateWorkspaceOptions{
DB: db,
ChatID: chatID,
AgentConnFn: agentConnFn,
AgentInactiveDisconnectTimeout: 30 * time.Second,
+10 -12
View File
@@ -1,11 +1,11 @@
package chattool
import (
"cmp"
"context"
"database/sql"
"maps"
"slices"
"sort"
"strings"
"charm.land/fantasy"
@@ -22,9 +22,7 @@ const listTemplatesPageSize = 10
// ListTemplatesOptions configures the list_templates tool.
type ListTemplatesOptions struct {
DB database.Store
OwnerID uuid.UUID
OrganizationID uuid.UUID
AllowedTemplateIDs func() map[uuid.UUID]bool
}
@@ -37,7 +35,7 @@ type listTemplatesArgs struct {
// The agent uses this to discover templates before creating a workspace.
// Results are ordered by number of active developers (most popular first)
// and paginated at 10 per page.
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
func ListTemplates(organizationID uuid.UUID, db database.Store, options ListTemplatesOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"list_templates",
"List available workspace templates. Optionally filter by a "+
@@ -46,18 +44,18 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
"Results are ordered by number of active developers (most popular first). "+
"Returns 10 per page. Use the page parameter to paginate through results.",
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
if db == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
ctx, err := asOwner(ctx, db, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
filterParams := database.GetTemplatesWithFilterParams{
Deleted: false,
OrganizationID: options.OrganizationID,
OrganizationID: organizationID,
Deprecated: sql.NullBool{
Bool: false,
Valid: true,
@@ -75,7 +73,7 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
if len(allowlist) > 0 {
filterParams.IDs = slices.Collect(maps.Keys(allowlist))
}
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
templates, err := db.GetTemplatesWithFilter(ctx, filterParams)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
@@ -87,7 +85,8 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
}
ownerCounts := make(map[uuid.UUID]int64)
if len(templateIDs) > 0 {
rows, countErr := options.DB.GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx, templateIDs)
rows, countErr := db.GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx, templateIDs)
if countErr == nil {
for _, row := range rows {
ownerCounts[row.TemplateID] = row.UniqueOwnersSum
@@ -96,10 +95,9 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
}
// Sort by active developer count descending.
sort.SliceStable(templates, func(i, j int) bool {
return ownerCounts[templates[i].ID] > ownerCounts[templates[j].ID]
slices.SortStableFunc(templates, func(a, b database.Template) int {
return cmp.Compare(ownerCounts[b.ID], ownerCounts[a.ID])
})
// Paginate.
page := args.Page
if page < 1 {
+120 -18
View File
@@ -17,6 +17,110 @@ import (
"github.com/coder/coder/v2/testutil"
)
func TestListTemplates_OrganizationFilter(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
orgA := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: orgA.ID,
})
orgB := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: orgB.ID,
})
tAlpha := dbgen.Template(t, db, database.Template{
OrganizationID: orgA.ID,
CreatedBy: user.ID,
Name: "alpha",
})
tBeta := dbgen.Template(t, db, database.Template{
OrganizationID: orgB.ID,
CreatedBy: user.ID,
Name: "beta",
})
t.Run("ScopedToOrgA", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
tool := chattool.ListTemplates(orgA.ID, db, chattool.ListTemplatesOptions{
OwnerID: user.ID,
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "org-a", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
require.False(t, resp.IsError)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
templates := result["templates"].([]any)
require.Len(t, templates, 1)
m := templates[0].(map[string]any)
require.Equal(t, tAlpha.ID.String(), m["id"].(string))
})
t.Run("NilOrgReturnsBoth", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
OwnerID: user.ID,
// Pass uuid.Nil to skip org filtering.
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "nil-org", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
require.False(t, resp.IsError)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
templates := result["templates"].([]any)
require.Len(t, templates, 2)
})
t.Run("ReadTemplate_CrossOrgRejected", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
// Tool scoped to orgA, but requesting a template in orgB.
tool := chattool.ReadTemplate(orgA.ID, db, chattool.ReadTemplateOptions{
OwnerID: user.ID,
})
input := `{"template_id":"` + tBeta.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "cross-org", Name: "read_template", Input: input})
require.NoError(t, err)
require.True(t, resp.IsError)
require.Contains(t, resp.Content, "not found")
})
t.Run("ReadTemplate_SameOrgAllowed", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
// Tool scoped to orgA, requesting a template in orgA.
tool := chattool.ReadTemplate(orgA.ID, db, chattool.ReadTemplateOptions{
OwnerID: user.ID,
})
input := `{"template_id":"` + tAlpha.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "same-org", Name: "read_template", Input: input})
require.NoError(t, err)
require.False(t, resp.IsError)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
tmplInfo := result["template"].(map[string]any)
require.Equal(t, tAlpha.ID.String(), tmplInfo["id"].(string))
})
}
//nolint:tparallel,paralleltest // Subtests share a single DB and run sequentially.
func TestTemplateAllowlistEnforcement(t *testing.T) {
t.Parallel()
@@ -43,10 +147,10 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
t.Run("ListTemplates", func(t *testing.T) {
t.Run("NoAllowlist", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
OwnerID: user.ID,
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c1", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
@@ -56,11 +160,11 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
})
t.Run("EmptyAllowlist", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{} },
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c2", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
@@ -70,11 +174,11 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
})
t.Run("OneMatch", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c3", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
@@ -86,11 +190,11 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
})
t.Run("NoMatches", func(t *testing.T) {
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
DB: db,
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c4", Name: "list_templates", Input: "{}"})
require.NoError(t, err)
var result map[string]any
@@ -102,8 +206,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
t.Run("ReadTemplate", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: db,
tool := chattool.ReadTemplate(org.ID, db, chattool.ReadTemplateOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
})
@@ -118,8 +221,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
})
t.Run("Disallowed", func(t *testing.T) {
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: db,
tool := chattool.ReadTemplate(org.ID, db, chattool.ReadTemplateOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
})
@@ -131,8 +233,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
})
t.Run("NoAllowlist", func(t *testing.T) {
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
DB: db,
tool := chattool.ReadTemplate(org.ID, db, chattool.ReadTemplateOptions{
OwnerID: user.ID,
})
input := `{"template_id":"` + t2.ID.String() + `"}`
@@ -145,15 +246,16 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
t.Run("CreateWorkspace", func(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
createCalled := false
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: db,
tool := chattool.CreateWorkspace(org.ID, db, chattool.CreateWorkspaceOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
createCalled = true
return codersdk.Workspace{}, nil
},
})
input := `{"template_id":"` + t1.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8a", Name: "create_workspace", Input: input})
require.NoError(t, err)
@@ -167,8 +269,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
t.Run("Disallowed", func(t *testing.T) {
createCalled := false
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
DB: db,
tool := chattool.CreateWorkspace(uuid.Nil, db, chattool.CreateWorkspaceOptions{
OwnerID: user.ID,
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
@@ -177,6 +278,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
return codersdk.Workspace{}, nil
},
})
input := `{"template_id":"` + t1.ID.String() + `"}`
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8", Name: "create_workspace", Input: input})
require.NoError(t, err)
+9 -6
View File
@@ -14,7 +14,6 @@ import (
// ReadTemplateOptions configures the read_template tool.
type ReadTemplateOptions struct {
DB database.Store
OwnerID uuid.UUID
AllowedTemplateIDs func() map[uuid.UUID]bool
}
@@ -26,7 +25,7 @@ type readTemplateArgs struct {
// ReadTemplate returns a tool that retrieves details about a specific
// template, including its configurable rich parameters. The agent
// uses this after list_templates and before create_workspace.
func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
func ReadTemplate(organizationID uuid.UUID, db database.Store, options ReadTemplateOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"read_template",
"Get details about a workspace template, including its "+
@@ -34,7 +33,7 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
"template with list_templates and before creating a "+
"workspace with create_workspace.",
func(ctx context.Context, args readTemplateArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.DB == nil {
if db == nil {
return fantasy.NewTextErrorResponse("database is not configured"), nil
}
@@ -53,17 +52,21 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
return fantasy.NewTextErrorResponse("template not found"), nil
}
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
ctx, err = asOwner(ctx, db, options.OwnerID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
template, err := options.DB.GetTemplateByID(ctx, templateID)
template, err := db.GetTemplateByID(ctx, templateID)
if err != nil {
return fantasy.NewTextErrorResponse("template not found"), nil
}
params, err := options.DB.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
if template.OrganizationID != organizationID {
return fantasy.NewTextErrorResponse("template not found"), nil
}
params, err := db.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("failed to get template parameters: %w", err).Error(),
+116
View File
@@ -1172,3 +1172,119 @@ func TestCreateChatNonDefaultOrg(t *testing.T) {
}
require.True(t, found, "chat should be visible in list")
}
func TestListChats_OrgAdminOnlySeesOwnChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, firstUser := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: func() *codersdk.DeploymentValues {
v := coderdtest.DeploymentValues(t)
v.Experiments = []string{string(codersdk.ExperimentAgents)}
return v
}(),
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureMultipleOrganizations: 1,
},
},
})
expClient := codersdk.NewExperimentalClient(client)
// Set up a chat provider and model config.
provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseURL: "https://example.com",
})
require.NoError(t, err)
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4o-mini",
DisplayName: "Test Model",
IsDefault: ptr.Ref(true),
ContextLimit: ptr.Ref(int64(1000)),
CompressionThreshold: ptr.Ref(int32(70)),
})
require.NoError(t, err)
// Create a second (non-default) org.
secondOrg := coderdenttest.CreateOrganization(t, client, coderdenttest.CreateOrganizationOptions{})
// Create a regular member with agents access in the second org.
memberClientRaw, member := coderdtest.CreateAnotherUser(
t, client, firstUser.OrganizationID, rbac.RoleAgentsAccess(),
)
_, err = client.PostOrganizationMember(ctx, secondOrg.ID, member.Username)
require.NoError(t, err)
memberExp := codersdk.NewExperimentalClient(memberClientRaw)
// Member creates a chat in the second org.
memberChat, err := memberExp.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: secondOrg.ID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello from member",
},
},
})
require.NoError(t, err)
require.Equal(t, secondOrg.ID, memberChat.OrganizationID)
// Create an org admin in the second org with agents access.
adminClientRaw, _ := coderdtest.CreateAnotherUser(
t, client, firstUser.OrganizationID,
rbac.ScopedRoleOrgAdmin(secondOrg.ID), rbac.RoleAgentsAccess(),
)
adminExp := codersdk.NewExperimentalClient(adminClientRaw)
// Admin creates a chat in the second org.
adminChat, err := adminExp.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: secondOrg.ID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "hello from admin",
},
},
})
require.NoError(t, err)
require.Equal(t, secondOrg.ID, adminChat.OrganizationID)
// Admin lists chats -- should only see their own chat.
// TODO: The handler currently filters by OwnerID (the
// authenticated user), so org admins cannot see other
// users' chats even though RBAC would allow it. If the
// handler gains an owner filter parameter, update this
// test to verify cross-user visibility.
adminChats, err := adminExp.ListChats(ctx, nil)
require.NoError(t, err)
var foundAdmin, foundMember bool
for _, c := range adminChats {
if c.ID == adminChat.ID {
foundAdmin = true
}
if c.ID == memberChat.ID {
foundMember = true
}
}
require.True(t, foundAdmin, "admin should see own chat")
require.False(t, foundMember, "admin should NOT see member chat (OwnerID filter)")
// Positive control: member can list their own chat.
memberChats, err := memberExp.ListChats(ctx, nil)
require.NoError(t, err)
var memberSeeOwn bool
for _, c := range memberChats {
if c.ID == memberChat.ID {
memberSeeOwn = true
}
}
require.True(t, memberSeeOwn, "member should see own chat")
}
@@ -27,6 +27,7 @@ export const Navbar: FC = () => {
featureVisibility.connection_log && permissions.viewAnyConnectionLog;
const canViewAIBridge =
featureVisibility.aibridge && permissions.viewAnyAIBridgeInterception;
const canCreateChat = permissions.createChat;
const uniqueLinks = new Map<string, LinkConfig>();
for (const link of appearance.support_links ?? []) {
@@ -47,6 +48,7 @@ export const Navbar: FC = () => {
canViewAuditLog={canViewAuditLog}
canViewConnectionLog={canViewConnectionLog}
canViewAIBridge={canViewAIBridge}
canCreateChat={canCreateChat}
proxyContextValue={proxyContextValue}
/>
);
@@ -35,6 +35,7 @@ const meta: Meta<typeof NavbarView> = {
canViewDeployment: true,
canViewHealth: true,
canViewOrganizations: true,
canCreateChat: true,
supportLinks: [],
},
decorators: [withDashboardProvider],
@@ -91,6 +92,18 @@ export const ForMember: Story = {
canViewDeployment: false,
canViewHealth: false,
canViewOrganizations: false,
canCreateChat: false,
},
};
export const ForMemberWithAgentsAccess: Story = {
args: {
user: MockUserMember,
canViewAuditLog: false,
canViewDeployment: false,
canViewHealth: false,
canViewOrganizations: false,
canCreateChat: true,
},
};
@@ -36,6 +36,7 @@ interface NavbarViewProps {
canViewConnectionLog: boolean;
canViewHealth: boolean;
canViewAIBridge: boolean;
canCreateChat: boolean;
proxyContextValue?: ProxyContextValue;
}
@@ -57,6 +58,7 @@ export const NavbarView: FC<NavbarViewProps> = ({
canViewAuditLog,
canViewConnectionLog,
canViewAIBridge,
canCreateChat,
proxyContextValue,
}) => {
const isDev = buildInfo ? isDevBuild(buildInfo) : false;
@@ -78,7 +80,7 @@ export const NavbarView: FC<NavbarViewProps> = ({
)}
</NavLink>
<NavItems className="ml-4" user={user} />
<NavItems className="ml-4" user={user} canCreateChat={canCreateChat} />
{isPreRelease && buildInfo?.version && (
<a
@@ -165,9 +167,10 @@ export const NavbarView: FC<NavbarViewProps> = ({
interface NavItemsProps {
className?: string;
user: TypesGen.User;
canCreateChat: boolean;
}
const NavItems: FC<NavItemsProps> = ({ className, user }) => {
const NavItems: FC<NavItemsProps> = ({ className, user, canCreateChat }) => {
const location = useLocation();
return (
@@ -192,7 +195,7 @@ const NavItems: FC<NavItemsProps> = ({ className, user }) => {
Templates
</NavLink>
<TasksNavItem user={user} />
<AgentsNavItem />
<AgentsNavItem canCreateChat={canCreateChat} />
</nav>
);
};
@@ -257,11 +260,12 @@ function idleTasksLabel(count: number) {
return `You have ${count} ${count === 1 ? "task" : "tasks"} waiting for input`;
}
const AgentsNavItem: FC = () => {
const AgentsNavItem: FC<{ canCreateChat: boolean }> = ({ canCreateChat }) => {
const { experiments, buildInfo } = useDashboard();
const canSeeAgents = experiments.includes("agents") || isDevBuild(buildInfo);
const experimentEnabled =
experiments.includes("agents") || isDevBuild(buildInfo);
if (!canSeeAgents) {
if (!experimentEnabled || !canCreateChat) {
return null;
}
+101 -1
View File
@@ -153,7 +153,7 @@ describe("useEmptyStateDraft", () => {
});
expect(localStorage.getItem(emptyInputStorageKey)).toBeNull();
// Simulate error recovery re-enable persistence.
// Simulate error recovery -- re-enable persistence.
act(() => {
result.current.resetDraft();
});
@@ -307,12 +307,14 @@ describe("useFileAttachments persistence", () => {
fileName: string;
fileType: string;
lastModified: number;
organizationId: string;
}> = {},
) => ({
fileId: "file-1",
fileName: "photo.png",
fileType: "image/png",
lastModified: 1000,
organizationId: "org-1",
...overrides,
});
@@ -504,4 +506,102 @@ describe("useFileAttachments persistence", () => {
expect(localStorage.getItem(persistedAttachmentsStorageKey)).toBeNull();
unmount();
});
it("restores only attachments matching current organization", () => {
const entries = [
makePersistedEntry({
fileId: "f1",
fileName: "a.png",
organizationId: "org-1",
}),
makePersistedEntry({
fileId: "f2",
fileName: "b.png",
organizationId: "org-2",
}),
];
localStorage.setItem(
persistedAttachmentsStorageKey,
JSON.stringify(entries),
);
const { result, unmount } = renderFileAttachments();
expect(result.current.attachments).toHaveLength(1);
expect(result.current.attachments[0].name).toBe("a.png");
// localStorage should be pruned to only the matching org.
const stored = JSON.parse(
localStorage.getItem(persistedAttachmentsStorageKey)!,
);
expect(stored).toHaveLength(1);
expect(stored[0].fileId).toBe("f1");
unmount();
});
it("prunes legacy entries without organizationId", () => {
const legacy = {
fileId: "old-file",
fileName: "legacy.png",
fileType: "image/png",
lastModified: 1000,
// No organizationId field -- simulates pre-org-scoping data.
};
localStorage.setItem(
persistedAttachmentsStorageKey,
JSON.stringify([legacy]),
);
const { result, unmount } = renderFileAttachments();
expect(result.current.attachments).toHaveLength(0);
expect(localStorage.getItem(persistedAttachmentsStorageKey)).toBeNull();
unmount();
});
it("skips restoration when organizationId is undefined", () => {
const entry = makePersistedEntry();
localStorage.setItem(
persistedAttachmentsStorageKey,
JSON.stringify([entry]),
);
const { result, unmount } = renderHook(() =>
useFileAttachments(undefined, { persist: true }),
);
// Should not restore -- org not yet known.
expect(result.current.attachments).toHaveLength(0);
// Should NOT prune -- org unknown, so leave storage alone.
expect(localStorage.getItem(persistedAttachmentsStorageKey)).not.toBeNull();
unmount();
});
it("persists organizationId with attachment metadata", async () => {
const { API } = await import("#/api/api");
vi.spyOn(API.experimental, "uploadChatFile").mockResolvedValue({
id: "new-file-id",
});
vi.spyOn(globalThis, "fetch").mockResolvedValue(new Response());
const { result, unmount } = renderFileAttachments();
const file = new File(["hello"], "test.png", { type: "image/png" });
act(() => {
result.current.handleAttach([file]);
});
await vi.waitFor(() => {
const state = result.current.uploadStates.get(file);
expect(state?.status).toBe("uploaded");
});
const stored = JSON.parse(
localStorage.getItem(persistedAttachmentsStorageKey)!,
);
expect(stored).toHaveLength(1);
expect(stored[0].organizationId).toBe("org-1");
unmount();
});
});
@@ -9,6 +9,13 @@ import {
import { withDashboardProvider } from "#/testHelpers/storybook";
import { AgentCreateForm } from "./AgentCreateForm";
// Query key used by permittedOrganizations() in the form.
const permittedOrgsKey = [
"organizations",
"permitted",
{ object: { resource_type: "chat" }, action: "create" },
];
const modelConfigID = "model-config-1";
const modelOptions = [
@@ -231,6 +238,7 @@ export const PreservesAttachmentsOnFailedSend: Story = {
fileName: "photo.png",
fileType: "image/png",
lastModified: 1000,
organizationId: "my-organization-id",
},
]),
);
@@ -319,6 +327,29 @@ export const WithOrganizationPicker: Story = {
parameters: {
showOrganizations: true,
organizations: [MockDefaultOrganization, MockOrganization2],
queries: [
{
key: permittedOrgsKey,
data: [MockDefaultOrganization, MockOrganization2],
},
],
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
// Verify the org picker rendered (component didn't crash).
await waitFor(() => {
expect(
canvas.getByTestId("organization-autocomplete"),
).toBeInTheDocument();
});
// Type into the chat input to trigger re-renders. If the
// permittedOrgs fallback is referentially unstable, this
// causes a render cascade that hits React's update limit.
const input = canvas.getByTestId("chat-message-input");
await userEvent.click(input);
await userEvent.keyboard("hello world");
// The org picker should still be present after typing.
expect(canvas.getByTestId("organization-autocomplete")).toBeInTheDocument();
},
};
@@ -1,7 +1,9 @@
import { type FC, useEffect, useRef, useState } from "react";
import { type FC, useEffect, useEffectEvent, useRef, useState } from "react";
import { useQuery } from "react-query";
import { Link } from "react-router";
import { toast } from "sonner";
import { isApiError } from "#/api/errors";
import { permittedOrganizations } from "#/api/queries/organizations";
import type * as TypesGen from "#/api/typesGenerated";
import { Alert, AlertDescription } from "#/components/Alert/Alert";
import { ErrorAlert } from "#/components/Alert/ErrorAlert";
@@ -193,9 +195,9 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
const stored = localStorage.getItem(selectedWorkspaceIdStorageKey);
if (!stored) return null;
// If workspaces haven't loaded yet, keep the stored value.
// It will be re-validated once the list arrives via
// filteredWorkspaces clearing the selection if stale.
// The stored value is kept optimistically until workspaces
// load. effectiveWorkspaceId (computed after render) drops
// it if it doesn't match the current org's workspaces.
if (workspaceOptions.length === 0) return stored;
// Validate the stored workspace still exists and belongs
@@ -257,12 +259,6 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
lastUsedModelID,
]);
// Keep a mutable ref to selectedWorkspaceId and selectedModel so
// that the onSend callback always sees the latest values without
// the shared input component re-rendering on every change.
const selectedWorkspaceIdRef = useRef(selectedWorkspaceId);
const selectedModelRef = useRef(selectedModel);
const organizationIdRef = useRef(organizationId);
const [userMCPServerIds, setUserMCPServerIds] = useState<string[] | null>(
null,
);
@@ -276,13 +272,6 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
}
return getDefaultMCPSelection(mcpServers ?? []);
})();
const selectedMCPServerIdsRef = useRef(effectiveMCPServerIds);
useEffect(() => {
selectedWorkspaceIdRef.current = selectedWorkspaceId;
selectedModelRef.current = selectedModel;
selectedMCPServerIdsRef.current = effectiveMCPServerIds;
organizationIdRef.current = organizationId;
});
const handleWorkspaceChange = (value: string | null) => {
if (value === null) {
setSelectedWorkspaceId(null);
@@ -298,27 +287,47 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
setUserSelectedModel(value);
};
const handleSend = async (message: string, fileIDs?: string[]) => {
submitDraft();
await onCreateChat({
message,
fileIDs,
workspaceId: selectedWorkspaceIdRef.current ?? undefined,
model: selectedModelRef.current || undefined,
organizationId: organizationIdRef.current,
mcpServerIds:
selectedMCPServerIdsRef.current.length > 0
? [...selectedMCPServerIdsRef.current]
: undefined,
}).catch((err) => {
// Re-enable draft persistence so the user can edit
// and retry after a failed send attempt, then rethrow
// so callers (handleSendWithAttachments) can preserve
// attachments on failure.
resetDraft();
throw err;
});
};
const isForbidden = !canCreateChat;
// Filter workspaces by the selected organization. We use
// client-side filtering of the full "owner:me" fetch rather
// than re-querying with an org filter because it avoids
// extra loading/error states on org change. The full list is
// already small (user's own workspaces) and limit: 0
// guarantees completeness. If workspace counts grow large
// enough to warrant pagination, this should switch to a
// server-side organization:<name> query filter.
const filteredWorkspaces =
showOrganizations && selectedOrg
? workspaceOptions.filter((ws) => ws.organization_id === selectedOrg.id)
: workspaceOptions;
const effectiveWorkspaceId =
selectedWorkspaceId !== null &&
(isWorkspacesLoading ||
filteredWorkspaces.some((ws) => ws.id === selectedWorkspaceId))
? selectedWorkspaceId
: null;
const handleSend = useEffectEvent(
async (message: string, fileIDs?: string[]) => {
submitDraft();
await onCreateChat({
message,
fileIDs,
workspaceId: effectiveWorkspaceId ?? undefined,
model: selectedModel || undefined,
organizationId,
mcpServerIds:
effectiveMCPServerIds.length > 0
? [...effectiveMCPServerIds]
: undefined,
}).catch((err) => {
resetDraft();
throw err;
});
},
);
const {
attachments,
@@ -357,20 +366,42 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
}
};
const isForbidden = !canCreateChat;
const permittedOrgsQuery = useQuery({
...permittedOrganizations({
object: { resource_type: "chat" },
action: "create",
}),
enabled: showOrganizations,
});
const permittedOrgs = permittedOrgsQuery.data ?? organizations;
// Filter workspaces by the selected organization. We use
// client-side filtering of the full "owner:me" fetch rather
// than re-querying with an org filter because it avoids
// extra loading/error states on org change. The full list is
// already small (user's own workspaces) and limit: 0
// guarantees completeness. If workspace counts grow large
// enough to warrant pagination, this should switch to a
// server-side organization:<name> query filter.
const filteredWorkspaces =
showOrganizations && selectedOrg
? workspaceOptions.filter((ws) => ws.organization_id === selectedOrg.id)
: workspaceOptions;
// Reconcile selectedOrg when permission filtering removes it.
// Only pure state setters run during render; side effects
// (localStorage, blob URL cleanup) run in the effect below.
const [prevPermittedOrgs, setPrevPermittedOrgs] = useState(permittedOrgs);
const [orgWasAdjusted, setOrgWasAdjusted] = useState(false);
if (permittedOrgs !== prevPermittedOrgs) {
setPrevPermittedOrgs(permittedOrgs);
if (selectedOrg && !permittedOrgs.some((o) => o.id === selectedOrg.id)) {
setSelectedOrg(permittedOrgs[0] ?? null);
setOrgWasAdjusted(true);
}
}
// Clean up workspace and attachment state after a programmatic
// org change from permission filtering. These calls have side
// effects (localStorage, blob URL revocation) that must not
// run during render.
const onOrgAdjusted = useEffectEvent(() => {
handleWorkspaceChange(null);
resetAttachments();
});
useEffect(() => {
if (orgWasAdjusted) {
setOrgWasAdjusted(false);
onOrgAdjusted();
}
}, [orgWasAdjusted]);
return (
<>
@@ -399,28 +430,33 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
)
) : null}
{workspacesError != null && <ErrorAlert error={workspacesError} />}
{showOrganizations && (
<div className="flex flex-col gap-2">
<Label htmlFor="organization">Organization</Label>
<OrganizationAutocomplete
id="organization"
required
value={selectedOrg}
options={organizations}
onChange={(newOrg) => {
const orgChanged = newOrg?.id !== selectedOrg?.id;
if (orgChanged && attachments.length > 0) {
setPendingOrgChange(newOrg);
return;
}
if (orgChanged) {
handleWorkspaceChange(null);
}
setSelectedOrg(newOrg);
}}
/>
</div>
{permittedOrgsQuery.error != null && (
<ErrorAlert error={permittedOrgsQuery.error} />
)}
{showOrganizations &&
!permittedOrgsQuery.isLoading &&
permittedOrgs.length > 1 && (
<div className="flex flex-col gap-2">
<Label htmlFor="organization">Organization</Label>
<OrganizationAutocomplete
id="organization"
required
value={selectedOrg}
options={permittedOrgs}
onChange={(newOrg) => {
const orgChanged = newOrg?.id !== selectedOrg?.id;
if (orgChanged && attachments.length > 0) {
setPendingOrgChange(newOrg);
return;
}
if (orgChanged) {
handleWorkspaceChange(null);
}
setSelectedOrg(newOrg);
}}
/>
</div>
)}
<AgentChatInput
onSend={handleSendWithAttachments}
placeholder="Ask Coder to build, fix bugs, or explore your project..."
@@ -449,7 +485,7 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
}}
onMCPAuthComplete={onMCPAuthComplete}
workspaceOptions={filteredWorkspaces}
selectedWorkspaceId={selectedWorkspaceId}
selectedWorkspaceId={effectiveWorkspaceId}
onWorkspaceChange={handleWorkspaceChange}
isWorkspaceLoading={isWorkspacesLoading}
/>
@@ -2,7 +2,7 @@ import {
type Dispatch,
type SetStateAction,
useEffect,
useRef,
useEffectEvent,
useState,
} from "react";
import { API } from "#/api/api";
@@ -21,18 +21,33 @@ interface PersistedAttachment {
fileName: string;
fileType: string;
lastModified: number;
organizationId: string;
}
/**
* Restore previously persisted attachments from localStorage.
* Creates synthetic File objects (empty blobs with correct metadata)
* and populates the corresponding Maps so the UI can render them.
*
* Only attachments matching `currentOrgId` are returned. Entries
* belonging to a different organization are pruned from storage.
*/
function restorePersistedAttachments(): {
function restorePersistedAttachments(currentOrgId: string): {
attachments: File[];
uploadStates: Map<File, UploadState>;
previewUrls: Map<File, string>;
} {
// When the org ID is not yet known (e.g. still loading), skip
// restoration entirely so we don't accidentally prune valid
// entries. The initializer only runs once, so the caller must
// ensure the org ID is available before mounting the hook.
if (!currentOrgId) {
return {
attachments: [],
uploadStates: new Map(),
previewUrls: new Map(),
};
}
const stored = localStorage.getItem(persistedAttachmentsStorageKey);
if (!stored) {
return {
@@ -43,11 +58,25 @@ function restorePersistedAttachments(): {
}
try {
const persisted: PersistedAttachment[] = JSON.parse(stored);
const matched = persisted.filter((p) => p.organizationId === currentOrgId);
// Prune entries that don't match the current org.
if (matched.length !== persisted.length) {
if (matched.length > 0) {
localStorage.setItem(
persistedAttachmentsStorageKey,
JSON.stringify(matched),
);
} else {
localStorage.removeItem(persistedAttachmentsStorageKey);
}
}
const attachments: File[] = [];
const uploadStates = new Map<File, UploadState>();
const previewUrls = new Map<File, string>();
for (const p of persisted) {
for (const p of matched) {
if (!p.fileId || !p.fileName) continue;
// Synthetic File used as a Map key only. Its content is
// never read because the existing file_id is reused at
@@ -72,7 +101,11 @@ function restorePersistedAttachments(): {
}
}
function addPersistedAttachment(file: File, fileId: string) {
function addPersistedAttachment(
file: File,
fileId: string,
organizationId: string,
) {
const stored = localStorage.getItem(persistedAttachmentsStorageKey);
let persisted: PersistedAttachment[];
try {
@@ -85,6 +118,7 @@ function addPersistedAttachment(file: File, fileId: string) {
fileName: file.name,
fileType: file.type,
lastModified: file.lastModified,
organizationId,
});
localStorage.setItem(
persistedAttachmentsStorageKey,
@@ -141,7 +175,7 @@ export function useFileAttachments(
// when persistence is enabled. Computed once on first render.
const [restored] = useState(() =>
persist
? restorePersistedAttachments()
? restorePersistedAttachments(organizationId ?? "")
: {
attachments: [] as File[],
uploadStates: new Map<File, UploadState>(),
@@ -157,16 +191,13 @@ export function useFileAttachments(
);
// Revoke blob URLs on unmount to prevent memory leaks.
const previewUrlsRef = useRef(previewUrls);
useEffect(() => {
previewUrlsRef.current = previewUrls;
const revokePreviewUrls = useEffectEvent(() => {
for (const [, url] of previewUrls) {
if (url.startsWith("blob:")) URL.revokeObjectURL(url);
}
});
useEffect(() => {
return () => {
for (const [, url] of previewUrlsRef.current) {
if (url.startsWith("blob:")) URL.revokeObjectURL(url);
}
};
return () => revokePreviewUrls();
}, []);
const startUpload = (file: File) => {
@@ -179,6 +210,10 @@ export function useFileAttachments(
);
return;
}
const shouldPersist = persist && Boolean(organizationId);
const isImage = file.type.startsWith("image/");
setUploadStates((prev) => new Map(prev).set(file, { status: "uploading" }));
void (async () => {
try {
@@ -192,14 +227,14 @@ export function useFileAttachments(
fileId: result.id,
}),
);
if (persist) {
addPersistedAttachment(file, result.id);
if (shouldPersist) {
addPersistedAttachment(file, result.id, organizationId!);
}
// Pre-warm the browser HTTP cache for images so the
// timeline can render them instantly after send. We
// intentionally skip text attachments because the
// composer already has the text content locally.
if (file.type.startsWith("image/")) {
if (isImage) {
void fetch(`/api/experimental/chats/files/${result.id}`);
}
} catch (err: unknown) {
@@ -310,9 +345,7 @@ export function useFileAttachments(
};
const resetAttachments = () => {
for (const [, url] of previewUrlsRef.current) {
if (url.startsWith("blob:")) URL.revokeObjectURL(url);
}
revokePreviewUrls();
setPreviewUrls(new Map());
setTextContents(new Map());
setUploadStates(new Map());