mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: address post-merge review findings for chat org scoping (#24297)
Addresses review findings from #23827 that were added post-merge: - Persisted attachments now store `organizationId`; mismatched orgs pruned on restore - Workspace selection reconciliation: stale IDs from previous orgs dropped via derived `effectiveWorkspaceId` - Org picker uses `permittedOrganizations()` for RBAC-aware filtering - Org picker hidden when user belongs to only one org - Ref-sync `useEffect` replaced with `useEffectEvent` - `CreateWorkspace()` and `ListTemplates()` take `organizationID` and `db` as required function parameters instead of optional struct fields — compiler enforces them, removes scattered nil guards - Cross-org template check in `CreateWorkspace` is now unconditional - `ListTemplates` org-scoping filter now has test coverage - `setupChatInfra` comment fixed; test helpers use params structs instead of positional UUIDs - Enterprise test documents that org admin only sees own chats (handler hardcodes `OwnerID` — future work needs sidebar UI before lifting that restriction) > 🤖
This commit is contained in:
@@ -10316,7 +10316,8 @@ func TestGetPRInsights(t *testing.T) {
|
||||
}
|
||||
|
||||
// setupChatInfra creates a fresh database with a user, chat provider,
|
||||
// and model config. Returns the store, user ID, and model config ID.
|
||||
// and model config. Returns the store, user ID, model config ID,
|
||||
// and org ID.
|
||||
setupChatInfra := func(t *testing.T) (database.Store, uuid.UUID, uuid.UUID, uuid.UUID) {
|
||||
t.Helper()
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
@@ -10351,13 +10352,20 @@ func TestGetPRInsights(t *testing.T) {
|
||||
return store, user.ID, mc.ID, org.ID
|
||||
}
|
||||
|
||||
createChat := func(t *testing.T, store database.Store, userID, mcID, orgID uuid.UUID, title string) database.Chat {
|
||||
type chatParams struct {
|
||||
Store database.Store
|
||||
UserID uuid.UUID
|
||||
ModelConfigID uuid.UUID
|
||||
OrgID uuid.UUID
|
||||
}
|
||||
|
||||
createChat := func(t *testing.T, p chatParams, title string) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
OrganizationID: orgID,
|
||||
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
OrganizationID: p.OrgID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: userID,
|
||||
LastModelConfigID: mcID,
|
||||
OwnerID: p.UserID,
|
||||
LastModelConfigID: p.ModelConfigID,
|
||||
Title: title,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -10416,11 +10424,12 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("MultipleChatsSamePR_CostSummed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
chatA := createChat(t, store, userID, mcID, orgID, "chat-A")
|
||||
chatA := createChat(t, p, "chat-A")
|
||||
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000) // $5
|
||||
|
||||
chatB := createChat(t, store, userID, mcID, orgID, "chat-B")
|
||||
chatB := createChat(t, p, "chat-B")
|
||||
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000) // $3
|
||||
|
||||
prURL := "https://github.com/org/repo/pull/123"
|
||||
@@ -10452,12 +10461,13 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("DifferentPRs_NoDuplication", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
chatA := createChat(t, store, userID, mcID, orgID, "chat-A")
|
||||
chatA := createChat(t, p, "chat-A")
|
||||
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000)
|
||||
linkPR(t, store, chatA.ID, "https://github.com/org/repo/pull/1", "merged", "feat: A", 50, 10, 2)
|
||||
|
||||
chatB := createChat(t, store, userID, mcID, orgID, "chat-B")
|
||||
chatB := createChat(t, p, "chat-B")
|
||||
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000)
|
||||
linkPR(t, store, chatB.ID, "https://github.com/org/repo/pull/2", "open", "feat: B", 80, 30, 4)
|
||||
|
||||
@@ -10486,13 +10496,13 @@ func TestGetPRInsights(t *testing.T) {
|
||||
|
||||
// createChildChat creates a chat with ParentChatID and RootChatID
|
||||
// set, simulating a subagent/child chat in a tree.
|
||||
createChildChat := func(t *testing.T, store database.Store, userID, mcID, orgID, parentID, rootID uuid.UUID, title string) database.Chat {
|
||||
createChildChat := func(t *testing.T, p chatParams, parentID, rootID uuid.UUID, title string) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
OrganizationID: orgID,
|
||||
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
|
||||
OrganizationID: p.OrgID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: userID,
|
||||
LastModelConfigID: mcID,
|
||||
OwnerID: p.UserID,
|
||||
LastModelConfigID: p.ModelConfigID,
|
||||
Title: title,
|
||||
ParentChatID: uuid.NullUUID{UUID: parentID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: rootID, Valid: true},
|
||||
@@ -10504,10 +10514,11 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("DuplicatePRUrl_CountedOnce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
prURL := "https://github.com/org/repo/pull/99"
|
||||
for i := 0; i < 3; i++ {
|
||||
chat := createChat(t, store, userID, mcID, orgID, fmt.Sprintf("chat-%d", i))
|
||||
for i := range 3 {
|
||||
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
|
||||
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
||||
linkPR(t, store, chat.ID, prURL, "merged", "fix: same PR", 40, 10, 3)
|
||||
}
|
||||
@@ -10533,18 +10544,19 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("ChildChatCostsIncluded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Parent chat with a $5 cost.
|
||||
parent := createChat(t, store, userID, mcID, orgID, "parent-chat")
|
||||
parent := createChat(t, p, "parent-chat")
|
||||
insertCostMessage(t, store, parent.ID, userID, mcID, 5_000_000)
|
||||
|
||||
// Two child chats (subagents) with $2 each. Only the parent
|
||||
// has a chat_diff_statuses entry, but the children's costs
|
||||
// should be included via the tree join.
|
||||
child1 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-1")
|
||||
child1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
||||
insertCostMessage(t, store, child1.ID, userID, mcID, 2_000_000)
|
||||
|
||||
child2 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-2")
|
||||
child2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
||||
insertCostMessage(t, store, child2.ID, userID, mcID, 2_000_000)
|
||||
|
||||
prURL := "https://github.com/org/repo/pull/42"
|
||||
@@ -10575,18 +10587,19 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("SiblingPRs_NoCrossContamination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Parent chat with $10 orchestration cost.
|
||||
parent := createChat(t, store, userID, mcID, orgID, "parent")
|
||||
parent := createChat(t, p, "parent")
|
||||
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
||||
|
||||
// Child C1 ($5) creates PR1.
|
||||
c1 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-1")
|
||||
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
||||
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
|
||||
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/10", "merged", "feat: PR1", 50, 10, 2)
|
||||
|
||||
// Child C2 ($3) creates PR2.
|
||||
c2 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-2")
|
||||
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
||||
insertCostMessage(t, store, c2.ID, userID, mcID, 3_000_000)
|
||||
linkPR(t, store, c2.ID, "https://github.com/org/repo/pull/11", "open", "feat: PR2", 30, 5, 1)
|
||||
|
||||
@@ -10618,22 +10631,23 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("ParentAndChildDifferentPRs_NoCrossContamination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Parent P ($10) creates PR1.
|
||||
parent := createChat(t, store, userID, mcID, orgID, "parent")
|
||||
parent := createChat(t, p, "parent")
|
||||
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
||||
linkPR(t, store, parent.ID, "https://github.com/org/repo/pull/20", "merged", "feat: parent PR", 80, 20, 4)
|
||||
|
||||
// Child C1 ($5) has its own PR2. Because C1 has its own
|
||||
// chat_diff_statuses entry, its cost should NOT be included
|
||||
// under PR1 — it belongs to PR2 only.
|
||||
c1 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-1")
|
||||
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
||||
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
|
||||
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/21", "open", "feat: child PR", 30, 5, 1)
|
||||
|
||||
// Child C2 ($2) has NO cds entry — pure subagent.
|
||||
// Its cost should be included under PR1 (the parent's PR).
|
||||
c2 := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child-2")
|
||||
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
||||
insertCostMessage(t, store, c2.ID, userID, mcID, 2_000_000)
|
||||
|
||||
// PR1 cost = parent ($10) + C2 ($2) = $12 (C1 excluded)
|
||||
@@ -10663,15 +10677,16 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("EmptyURLNotCollapsed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Two chats with empty-string URLs should be treated as
|
||||
// separate PRs (NULLIF converts '' to NULL, falling back
|
||||
// to c.id::text).
|
||||
chatX := createChat(t, store, userID, mcID, orgID, "chat-X")
|
||||
chatX := createChat(t, p, "chat-X")
|
||||
insertCostMessage(t, store, chatX.ID, userID, mcID, 4_000_000)
|
||||
linkPR(t, store, chatX.ID, "", "open", "draft: X", 10, 2, 1)
|
||||
|
||||
chatY := createChat(t, store, userID, mcID, orgID, "chat-Y")
|
||||
chatY := createChat(t, p, "chat-Y")
|
||||
insertCostMessage(t, store, chatY.ID, userID, mcID, 6_000_000)
|
||||
linkPR(t, store, chatY.ID, "", "merged", "draft: Y", 20, 5, 2)
|
||||
|
||||
@@ -10696,13 +10711,14 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("ParentAndChildSameURL_DedupedWithCombinedCost", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Parent P ($10) links to a PR.
|
||||
parent := createChat(t, store, userID, mcID, orgID, "parent")
|
||||
parent := createChat(t, p, "parent")
|
||||
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
||||
|
||||
// Child C ($5) also links to the same PR URL.
|
||||
child := createChildChat(t, store, userID, mcID, orgID, parent.ID, parent.ID, "child")
|
||||
child := createChildChat(t, p, parent.ID, parent.ID, "child")
|
||||
insertCostMessage(t, store, child.ID, userID, mcID, 5_000_000)
|
||||
|
||||
prURL := "https://github.com/org/repo/pull/50"
|
||||
@@ -10733,10 +10749,11 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("ZeroCostChat_StillCounted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// A chat linked to a PR but with NO chat_messages at all.
|
||||
// The PR should still appear with zero cost.
|
||||
chat := createChat(t, store, userID, mcID, orgID, "zero-cost-chat")
|
||||
chat := createChat(t, p, "zero-cost-chat")
|
||||
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/60", "open", "feat: no messages", 25, 5, 2)
|
||||
|
||||
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
||||
@@ -10777,7 +10794,8 @@ func TestGetPRInsights(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat := createChat(t, store, userID, emptyDisplayModel.ID, orgID, "chat-empty-display-name")
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: emptyDisplayModel.ID, OrgID: orgID}
|
||||
chat := createChat(t, p, "chat-empty-display-name")
|
||||
insertCostMessage(t, store, chat.ID, userID, emptyDisplayModel.ID, 1_000_000)
|
||||
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/72", "merged", "fix: blank display name", 10, 2, 1)
|
||||
|
||||
@@ -10803,14 +10821,15 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("MergedCostMicros_OnlyCountsMerged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Merged PR with $5 cost.
|
||||
chatMerged := createChat(t, store, userID, mcID, orgID, "chat-merged")
|
||||
chatMerged := createChat(t, p, "chat-merged")
|
||||
insertCostMessage(t, store, chatMerged.ID, userID, mcID, 5_000_000)
|
||||
linkPR(t, store, chatMerged.ID, "https://github.com/org/repo/pull/70", "merged", "fix: merged", 40, 10, 2)
|
||||
|
||||
// Open PR with $3 cost.
|
||||
chatOpen := createChat(t, store, userID, mcID, orgID, "chat-open")
|
||||
chatOpen := createChat(t, p, "chat-open")
|
||||
insertCostMessage(t, store, chatOpen.ID, userID, mcID, 3_000_000)
|
||||
linkPR(t, store, chatOpen.ID, "https://github.com/org/repo/pull/71", "open", "feat: open", 20, 5, 1)
|
||||
|
||||
@@ -10829,12 +10848,13 @@ func TestGetPRInsights(t *testing.T) {
|
||||
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID, orgID := setupChatInfra(t)
|
||||
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
||||
|
||||
// Create 25 distinct PRs — more than the old LIMIT 20 — and
|
||||
// verify all are returned.
|
||||
const prCount = 25
|
||||
for i := range prCount {
|
||||
chat := createChat(t, store, userID, mcID, orgID, fmt.Sprintf("chat-%d", i))
|
||||
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
|
||||
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
||||
linkPR(t, store, chat.ID,
|
||||
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
|
||||
|
||||
+7
-11
@@ -5059,21 +5059,16 @@ func (p *Server) runChat(
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
tools = append(tools,
|
||||
chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: chat.OrganizationID,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: p.db,
|
||||
chattool.ListTemplates(chat.OrganizationID, p.db, chattool.ListTemplatesOptions{
|
||||
OwnerID: chat.OwnerID,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: p.db,
|
||||
chattool.ReadTemplate(chat.OrganizationID, p.db, chattool.ReadTemplateOptions{
|
||||
OwnerID: chat.OwnerID,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
chattool.CreateWorkspace(chat.OrganizationID, p.db, chattool.CreateWorkspaceOptions{
|
||||
OwnerID: chat.OwnerID,
|
||||
OrganizationID: chat.OrganizationID,
|
||||
ChatID: chat.ID,
|
||||
CreateFn: p.createWorkspaceFn,
|
||||
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
|
||||
@@ -5083,6 +5078,7 @@ func (p *Server) runChat(
|
||||
Logger: p.logger,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
|
||||
chattool.StartWorkspace(chattool.StartWorkspaceOptions{
|
||||
DB: p.db,
|
||||
OwnerID: chat.OwnerID,
|
||||
|
||||
@@ -61,9 +61,7 @@ type AgentConnFunc func(
|
||||
|
||||
// CreateWorkspaceOptions configures the create_workspace tool.
|
||||
type CreateWorkspaceOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
OrganizationID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
CreateFn CreateWorkspaceFn
|
||||
AgentConnFn AgentConnFunc
|
||||
@@ -85,7 +83,7 @@ type createWorkspaceArgs struct {
|
||||
// workspace that is building or running, it returns the existing
|
||||
// workspace instead of creating a new one. A mutex prevents parallel
|
||||
// calls from creating duplicate workspaces.
|
||||
func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
func CreateWorkspace(organizationID uuid.UUID, db database.Store, options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"create_workspace",
|
||||
"Create a new workspace from a template. Requires a "+
|
||||
@@ -96,6 +94,9 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
"workspace that is building or running, the existing "+
|
||||
"workspace is returned.",
|
||||
func(ctx context.Context, args createWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if db == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
if options.CreateFn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace creator is not configured"), nil
|
||||
}
|
||||
@@ -123,7 +124,7 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
// Check for an existing workspace on the chat.
|
||||
check := options.checkExistingWorkspace(ctx)
|
||||
check := options.checkExistingWorkspace(ctx, db)
|
||||
if check.Err != nil {
|
||||
if check.FailedBuildID != uuid.Nil {
|
||||
return buildToolResponse(newBuildError(check.Err.Error(), check.FailedBuildID)), nil
|
||||
@@ -136,50 +137,44 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
ownerID := options.OwnerID
|
||||
|
||||
// Set up dbauthz context for DB lookups.
|
||||
if options.DB != nil {
|
||||
ownerCtx, ownerErr := asOwner(ctx, options.DB, ownerID)
|
||||
if ownerErr != nil {
|
||||
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
|
||||
}
|
||||
ctx = ownerCtx
|
||||
ownerCtx, ownerErr := asOwner(ctx, db, ownerID)
|
||||
if ownerErr != nil {
|
||||
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
|
||||
}
|
||||
ctx = ownerCtx
|
||||
|
||||
// Verify the template belongs to the same org as the
|
||||
// chat. Without this check the tool could silently
|
||||
// bind a cross-org workspace to the chat.
|
||||
if options.DB != nil && options.OrganizationID != uuid.Nil {
|
||||
tmpl, tmplErr := options.DB.GetTemplateByID(ctx, templateID)
|
||||
if tmplErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("look up template: %w", tmplErr).Error(),
|
||||
), nil
|
||||
}
|
||||
if tmpl.OrganizationID != options.OrganizationID {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"template belongs to a different organization than this chat; " +
|
||||
"use list_templates to find templates in the correct organization",
|
||||
), nil
|
||||
}
|
||||
tmpl, tmplErr := db.GetTemplateByID(ctx, templateID)
|
||||
if tmplErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("look up template: %w", tmplErr).Error(),
|
||||
), nil
|
||||
}
|
||||
if tmpl.OrganizationID != organizationID {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"template belongs to a different organization than this chat; " +
|
||||
"use list_templates to find templates in the correct organization",
|
||||
), nil
|
||||
}
|
||||
|
||||
var ttlMs *int64
|
||||
if options.DB != nil {
|
||||
raw, err := options.DB.GetChatWorkspaceTTL(ctx)
|
||||
if err != nil {
|
||||
options.Logger.Error(ctx, "failed to read chat workspace TTL setting, using template default",
|
||||
slog.Error(err),
|
||||
raw, err := db.GetChatWorkspaceTTL(ctx)
|
||||
if err != nil {
|
||||
options.Logger.Error(ctx, "failed to read chat workspace TTL setting, using template default",
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
d, parseErr := codersdk.ParseChatWorkspaceTTL(raw)
|
||||
if parseErr != nil {
|
||||
options.Logger.Warn(ctx, "invalid chat workspace TTL setting, using template default",
|
||||
slog.F("raw", raw),
|
||||
slog.Error(parseErr),
|
||||
)
|
||||
} else {
|
||||
d, parseErr := codersdk.ParseChatWorkspaceTTL(raw)
|
||||
if parseErr != nil {
|
||||
options.Logger.Warn(ctx, "invalid chat workspace TTL setting, using template default",
|
||||
slog.F("raw", raw),
|
||||
slog.Error(parseErr),
|
||||
)
|
||||
} else if d > 0 {
|
||||
ms := d.Milliseconds()
|
||||
ttlMs = &ms
|
||||
}
|
||||
} else if d > 0 {
|
||||
ms := d.Milliseconds()
|
||||
ttlMs = &ms
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,21 +183,9 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
TTLMillis: ttlMs,
|
||||
}
|
||||
|
||||
// Resolve workspace name. This does a second
|
||||
// GetTemplateByID when no name is provided; the first
|
||||
// is the org-validation check above. Consolidating
|
||||
// them would couple the security gate to the
|
||||
// name-fallback path, and the cost is negligible next
|
||||
// to the workspace build that follows.
|
||||
name := strings.TrimSpace(args.Name)
|
||||
if name == "" {
|
||||
seed := "workspace"
|
||||
if options.DB != nil {
|
||||
if t, lookupErr := options.DB.GetTemplateByID(ctx, templateID); lookupErr == nil {
|
||||
seed = t.Name
|
||||
}
|
||||
}
|
||||
name = generatedWorkspaceName(seed)
|
||||
name = generatedWorkspaceName(tmpl.Name)
|
||||
} else if err := codersdk.NameValid(name); err != nil {
|
||||
name = generatedWorkspaceName(name)
|
||||
}
|
||||
@@ -228,8 +211,8 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
// later fails. The checkExistingWorkspace recovery
|
||||
// path handles failed workspaces by allowing
|
||||
// re-creation.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
if options.ChatID != uuid.Nil {
|
||||
updatedChat, err := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
ID: options.ChatID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
@@ -259,8 +242,8 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
// come online so subsequent tools can use the
|
||||
// workspace immediately.
|
||||
buildID := workspace.LatestBuild.ID
|
||||
if options.DB != nil && buildID != uuid.Nil {
|
||||
if err := waitForBuild(ctx, options.DB, buildID); err != nil {
|
||||
if buildID != uuid.Nil {
|
||||
if err := waitForBuild(ctx, db, buildID); err != nil {
|
||||
return buildToolResponse(newBuildError(
|
||||
xerrors.Errorf("workspace build failed: %w", err).Error(),
|
||||
buildID,
|
||||
@@ -276,26 +259,24 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
// Select the chat agent so follow-up tools wait on the
|
||||
// intended workspace agent.
|
||||
workspaceAgentID := uuid.Nil
|
||||
if options.DB != nil {
|
||||
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if agentErr == nil {
|
||||
if len(agents) == 0 {
|
||||
result["agent_status"] = "no_agent"
|
||||
agents, agentErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if agentErr == nil {
|
||||
if len(agents) == 0 {
|
||||
result["agent_status"] = "no_agent"
|
||||
} else {
|
||||
selected, selectErr := agentselect.FindChatAgent(agents)
|
||||
if selectErr != nil {
|
||||
result["agent_status"] = "selection_error"
|
||||
result["agent_error"] = selectErr.Error()
|
||||
} else {
|
||||
selected, selectErr := agentselect.FindChatAgent(agents)
|
||||
if selectErr != nil {
|
||||
result["agent_status"] = "selection_error"
|
||||
result["agent_error"] = selectErr.Error()
|
||||
} else {
|
||||
workspaceAgentID = selected.ID
|
||||
}
|
||||
workspaceAgentID = selected.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the agent to come online and startup scripts to finish.
|
||||
if workspaceAgentID != uuid.Nil {
|
||||
agentStatus := waitForAgentReady(ctx, options.DB, workspaceAgentID, options.AgentConnFn)
|
||||
agentStatus := waitForAgentReady(ctx, db, workspaceAgentID, options.AgentConnFn)
|
||||
for k, v := range agentStatus {
|
||||
result[k] = v
|
||||
}
|
||||
@@ -327,12 +308,12 @@ type existingWorkspaceResult struct {
|
||||
// (workspace is dead or missing).
|
||||
func (o CreateWorkspaceOptions) checkExistingWorkspace(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
) existingWorkspaceResult {
|
||||
if o.DB == nil || o.ChatID == uuid.Nil {
|
||||
if o.ChatID == uuid.Nil {
|
||||
return existingWorkspaceResult{}
|
||||
}
|
||||
|
||||
db := o.DB
|
||||
chatID := o.ChatID
|
||||
agentConnFn := o.AgentConnFn
|
||||
agentInactiveDisconnectTimeout := o.AgentInactiveDisconnectTimeout
|
||||
|
||||
@@ -126,6 +126,7 @@ func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
@@ -142,6 +143,13 @@ func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetTemplateByID(gomock.Any(), templateID).
|
||||
Return(database.Template{
|
||||
ID: templateID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return("0s", nil)
|
||||
@@ -187,9 +195,9 @@ func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
|
||||
return nil, func() {}, nil
|
||||
}
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
|
||||
OwnerID: ownerID,
|
||||
|
||||
CreateFn: createFn,
|
||||
AgentConnFn: agentConnFn,
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
@@ -218,6 +226,7 @@ func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
@@ -235,9 +244,16 @@ func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
|
||||
Groups: []string{},
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetTemplateByID(gomock.Any(), templateID).
|
||||
Return(database.Template{
|
||||
ID: templateID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return("0s", nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetWorkspaceBuildByID(gomock.Any(), buildID).
|
||||
Return(database.WorkspaceBuild{
|
||||
@@ -269,10 +285,10 @@ func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
|
||||
{ID: uuid.New(), Name: "beta-coderd-chat", DisplayOrder: 1},
|
||||
}, nil)
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
|
||||
OwnerID: ownerID,
|
||||
ChatID: chatID,
|
||||
|
||||
ChatID: chatID,
|
||||
CreateFn: func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
return codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
@@ -315,6 +331,7 @@ func TestCreateWorkspace_PostCreationBuildFailure(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
@@ -329,6 +346,13 @@ func TestCreateWorkspace_PostCreationBuildFailure(t *testing.T) {
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetTemplateByID(gomock.Any(), templateID).
|
||||
Return(database.Template{
|
||||
ID: templateID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return("0s", nil)
|
||||
@@ -362,9 +386,9 @@ func TestCreateWorkspace_PostCreationBuildFailure(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
|
||||
OwnerID: ownerID,
|
||||
|
||||
ChatID: uuid.Nil,
|
||||
CreateFn: createFn,
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
@@ -426,6 +450,7 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
@@ -440,6 +465,13 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetTemplateByID(gomock.Any(), templateID).
|
||||
Return(database.Template{
|
||||
ID: templateID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return(tc.ttlReturn, tc.ttlErr)
|
||||
@@ -475,9 +507,9 @@ func TestCreateWorkspace_GlobalTTL(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
|
||||
OwnerID: ownerID,
|
||||
|
||||
ChatID: uuid.Nil,
|
||||
CreateFn: createFn,
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
@@ -546,11 +578,10 @@ func TestCreateWorkspace_RejectsCrossOrgTemplate(t *testing.T) {
|
||||
}, nil)
|
||||
|
||||
createCalled := false
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: chatOrgID,
|
||||
ChatID: chatID,
|
||||
tool := CreateWorkspace(chatOrgID, db, CreateWorkspaceOptions{
|
||||
OwnerID: ownerID,
|
||||
|
||||
ChatID: chatID,
|
||||
CreateFn: func(context.Context, uuid.UUID, codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
createCalled = true
|
||||
return codersdk.Workspace{}, nil
|
||||
@@ -610,8 +641,9 @@ func TestCheckExistingWorkspace_ConnectedAgent(t *testing.T) {
|
||||
return nil, nil, xerrors.New("unexpected agent dial")
|
||||
}
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, connFn)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, connFn)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.NoError(t, check.Err)
|
||||
require.True(t, check.Done)
|
||||
require.Equal(t, "already_exists", check.Result["status"])
|
||||
@@ -702,8 +734,9 @@ func TestCheckExistingWorkspace_InProgressBuildReturnsBuildID(t *testing.T) {
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{}, nil)
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.NoError(t, check.Err)
|
||||
require.True(t, check.Done)
|
||||
require.Equal(t, false, check.Result["created"])
|
||||
@@ -784,8 +817,9 @@ func TestCheckExistingWorkspace_InProgressBuildFailureReturnsBuildID(t *testing.
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
}, nil)
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.Error(t, check.Err)
|
||||
require.Contains(t, check.Err.Error(), "existing workspace build failed")
|
||||
require.Equal(t, buildID, check.FailedBuildID)
|
||||
@@ -831,8 +865,9 @@ func TestCheckExistingWorkspace_ConnectingAgentWaits(t *testing.T) {
|
||||
return nil, func() {}, nil
|
||||
}
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, connFn)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, connFn)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.NoError(t, check.Err)
|
||||
require.True(t, check.Done)
|
||||
require.Equal(t, 1, connectCalls)
|
||||
@@ -891,8 +926,9 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) {
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{tc.agent}, nil)
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.NoError(t, check.Err)
|
||||
require.False(t, check.Done)
|
||||
require.Nil(t, check.Result)
|
||||
@@ -907,6 +943,7 @@ func TestWaitForBuild_CanceledJob(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
orgID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
@@ -921,6 +958,13 @@ func TestWaitForBuild_CanceledJob(t *testing.T) {
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetTemplateByID(gomock.Any(), templateID).
|
||||
Return(database.Template{
|
||||
ID: templateID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return("0s", nil)
|
||||
@@ -953,9 +997,9 @@ func TestWaitForBuild_CanceledJob(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
tool := CreateWorkspace(orgID, db, CreateWorkspaceOptions{
|
||||
OwnerID: ownerID,
|
||||
|
||||
ChatID: uuid.Nil,
|
||||
CreateFn: createFn,
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
@@ -997,8 +1041,9 @@ func TestCheckExistingWorkspace_StoppedWorkspace(t *testing.T) {
|
||||
database.WorkspaceTransitionStop,
|
||||
)
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.True(t, check.Done)
|
||||
require.NoError(t, check.Err)
|
||||
require.Equal(t, "stopped", check.Result["status"])
|
||||
@@ -1029,20 +1074,19 @@ func TestCheckExistingWorkspace_DeletedWorkspace(t *testing.T) {
|
||||
Deleted: true,
|
||||
}, nil)
|
||||
|
||||
options := testCheckExistingWorkspaceOptions(db, chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background())
|
||||
options := testCheckExistingWorkspaceOptions(chatID, nil)
|
||||
check := options.checkExistingWorkspace(context.Background(), db)
|
||||
|
||||
require.NoError(t, check.Err)
|
||||
require.False(t, check.Done, "should allow creation for deleted workspace")
|
||||
require.Nil(t, check.Result)
|
||||
}
|
||||
|
||||
func testCheckExistingWorkspaceOptions(
|
||||
db *dbmock.MockStore,
|
||||
chatID uuid.UUID,
|
||||
agentConnFn AgentConnFunc,
|
||||
) CreateWorkspaceOptions {
|
||||
return CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
ChatID: chatID,
|
||||
AgentConnFn: agentConnFn,
|
||||
AgentInactiveDisconnectTimeout: 30 * time.Second,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"database/sql"
|
||||
"maps"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -22,9 +22,7 @@ const listTemplatesPageSize = 10
|
||||
|
||||
// ListTemplatesOptions configures the list_templates tool.
|
||||
type ListTemplatesOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
OrganizationID uuid.UUID
|
||||
AllowedTemplateIDs func() map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
@@ -37,7 +35,7 @@ type listTemplatesArgs struct {
|
||||
// The agent uses this to discover templates before creating a workspace.
|
||||
// Results are ordered by number of active developers (most popular first)
|
||||
// and paginated at 10 per page.
|
||||
func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
func ListTemplates(organizationID uuid.UUID, db database.Store, options ListTemplatesOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"list_templates",
|
||||
"List available workspace templates. Optionally filter by a "+
|
||||
@@ -46,18 +44,18 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
"Results are ordered by number of active developers (most popular first). "+
|
||||
"Returns 10 per page. Use the page parameter to paginate through results.",
|
||||
func(ctx context.Context, args listTemplatesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.DB == nil {
|
||||
if db == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
|
||||
ctx, err := asOwner(ctx, options.DB, options.OwnerID)
|
||||
ctx, err := asOwner(ctx, db, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
filterParams := database.GetTemplatesWithFilterParams{
|
||||
Deleted: false,
|
||||
OrganizationID: options.OrganizationID,
|
||||
OrganizationID: organizationID,
|
||||
Deprecated: sql.NullBool{
|
||||
Bool: false,
|
||||
Valid: true,
|
||||
@@ -75,7 +73,7 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
if len(allowlist) > 0 {
|
||||
filterParams.IDs = slices.Collect(maps.Keys(allowlist))
|
||||
}
|
||||
templates, err := options.DB.GetTemplatesWithFilter(ctx, filterParams)
|
||||
templates, err := db.GetTemplatesWithFilter(ctx, filterParams)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
@@ -87,7 +85,8 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
}
|
||||
ownerCounts := make(map[uuid.UUID]int64)
|
||||
if len(templateIDs) > 0 {
|
||||
rows, countErr := options.DB.GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx, templateIDs)
|
||||
rows, countErr := db.GetWorkspaceUniqueOwnerCountByTemplateIDs(ctx, templateIDs)
|
||||
|
||||
if countErr == nil {
|
||||
for _, row := range rows {
|
||||
ownerCounts[row.TemplateID] = row.UniqueOwnersSum
|
||||
@@ -96,10 +95,9 @@ func ListTemplates(options ListTemplatesOptions) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
// Sort by active developer count descending.
|
||||
sort.SliceStable(templates, func(i, j int) bool {
|
||||
return ownerCounts[templates[i].ID] > ownerCounts[templates[j].ID]
|
||||
slices.SortStableFunc(templates, func(a, b database.Template) int {
|
||||
return cmp.Compare(ownerCounts[b.ID], ownerCounts[a.ID])
|
||||
})
|
||||
|
||||
// Paginate.
|
||||
page := args.Page
|
||||
if page < 1 {
|
||||
|
||||
@@ -17,6 +17,110 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestListTemplates_OrganizationFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
orgA := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: orgA.ID,
|
||||
})
|
||||
orgB := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: orgB.ID,
|
||||
})
|
||||
|
||||
tAlpha := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: orgA.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "alpha",
|
||||
})
|
||||
tBeta := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: orgB.ID,
|
||||
CreatedBy: user.ID,
|
||||
Name: "beta",
|
||||
})
|
||||
|
||||
t.Run("ScopedToOrgA", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
tool := chattool.ListTemplates(orgA.ID, db, chattool.ListTemplatesOptions{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "org-a", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
templates := result["templates"].([]any)
|
||||
require.Len(t, templates, 1)
|
||||
m := templates[0].(map[string]any)
|
||||
require.Equal(t, tAlpha.ID.String(), m["id"].(string))
|
||||
})
|
||||
|
||||
t.Run("NilOrgReturnsBoth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
|
||||
OwnerID: user.ID,
|
||||
// Pass uuid.Nil to skip org filtering.
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "nil-org", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
templates := result["templates"].([]any)
|
||||
require.Len(t, templates, 2)
|
||||
})
|
||||
|
||||
t.Run("ReadTemplate_CrossOrgRejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Tool scoped to orgA, but requesting a template in orgB.
|
||||
tool := chattool.ReadTemplate(orgA.ID, db, chattool.ReadTemplateOptions{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
|
||||
input := `{"template_id":"` + tBeta.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "cross-org", Name: "read_template", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.IsError)
|
||||
require.Contains(t, resp.Content, "not found")
|
||||
})
|
||||
|
||||
t.Run("ReadTemplate_SameOrgAllowed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Tool scoped to orgA, requesting a template in orgA.
|
||||
tool := chattool.ReadTemplate(orgA.ID, db, chattool.ReadTemplateOptions{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
|
||||
input := `{"template_id":"` + tAlpha.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "same-org", Name: "read_template", Input: input})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
tmplInfo := result["template"].(map[string]any)
|
||||
require.Equal(t, tAlpha.ID.String(), tmplInfo["id"].(string))
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single DB and run sequentially.
|
||||
func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -43,10 +147,10 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
|
||||
t.Run("ListTemplates", func(t *testing.T) {
|
||||
t.Run("NoAllowlist", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c1", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
@@ -56,11 +160,11 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("EmptyAllowlist", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{} },
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c2", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
@@ -70,11 +174,11 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("OneMatch", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c3", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
@@ -86,11 +190,11 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("NoMatches", func(t *testing.T) {
|
||||
tool := chattool.ListTemplates(chattool.ListTemplatesOptions{
|
||||
DB: db,
|
||||
tool := chattool.ListTemplates(uuid.Nil, db, chattool.ListTemplatesOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c4", Name: "list_templates", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
var result map[string]any
|
||||
@@ -102,8 +206,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
|
||||
t.Run("ReadTemplate", func(t *testing.T) {
|
||||
t.Run("Allowed", func(t *testing.T) {
|
||||
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: db,
|
||||
tool := chattool.ReadTemplate(org.ID, db, chattool.ReadTemplateOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
|
||||
})
|
||||
@@ -118,8 +221,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Disallowed", func(t *testing.T) {
|
||||
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: db,
|
||||
tool := chattool.ReadTemplate(org.ID, db, chattool.ReadTemplateOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
|
||||
})
|
||||
@@ -131,8 +233,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("NoAllowlist", func(t *testing.T) {
|
||||
tool := chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
||||
DB: db,
|
||||
tool := chattool.ReadTemplate(org.ID, db, chattool.ReadTemplateOptions{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
input := `{"template_id":"` + t2.ID.String() + `"}`
|
||||
@@ -145,15 +246,16 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
t.Run("CreateWorkspace", func(t *testing.T) {
|
||||
t.Run("Allowed", func(t *testing.T) {
|
||||
createCalled := false
|
||||
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
tool := chattool.CreateWorkspace(org.ID, db, chattool.CreateWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{t1.ID: true} },
|
||||
|
||||
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
createCalled = true
|
||||
return codersdk.Workspace{}, nil
|
||||
},
|
||||
})
|
||||
|
||||
input := `{"template_id":"` + t1.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8a", Name: "create_workspace", Input: input})
|
||||
require.NoError(t, err)
|
||||
@@ -167,8 +269,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
|
||||
t.Run("Disallowed", func(t *testing.T) {
|
||||
createCalled := false
|
||||
tool := chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
tool := chattool.CreateWorkspace(uuid.Nil, db, chattool.CreateWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
AllowedTemplateIDs: func() map[uuid.UUID]bool { return map[uuid.UUID]bool{uuid.New(): true} },
|
||||
CreateFn: func(_ context.Context, _ uuid.UUID, _ codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
@@ -177,6 +278,7 @@ func TestTemplateAllowlistEnforcement(t *testing.T) {
|
||||
return codersdk.Workspace{}, nil
|
||||
},
|
||||
})
|
||||
|
||||
input := `{"template_id":"` + t1.ID.String() + `"}`
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "c8", Name: "create_workspace", Input: input})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
// ReadTemplateOptions configures the read_template tool.
|
||||
type ReadTemplateOptions struct {
|
||||
DB database.Store
|
||||
OwnerID uuid.UUID
|
||||
AllowedTemplateIDs func() map[uuid.UUID]bool
|
||||
}
|
||||
@@ -26,7 +25,7 @@ type readTemplateArgs struct {
|
||||
// ReadTemplate returns a tool that retrieves details about a specific
|
||||
// template, including its configurable rich parameters. The agent
|
||||
// uses this after list_templates and before create_workspace.
|
||||
func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
|
||||
func ReadTemplate(organizationID uuid.UUID, db database.Store, options ReadTemplateOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"read_template",
|
||||
"Get details about a workspace template, including its "+
|
||||
@@ -34,7 +33,7 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
|
||||
"template with list_templates and before creating a "+
|
||||
"workspace with create_workspace.",
|
||||
func(ctx context.Context, args readTemplateArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.DB == nil {
|
||||
if db == nil {
|
||||
return fantasy.NewTextErrorResponse("database is not configured"), nil
|
||||
}
|
||||
|
||||
@@ -53,17 +52,21 @@ func ReadTemplate(options ReadTemplateOptions) fantasy.AgentTool {
|
||||
return fantasy.NewTextErrorResponse("template not found"), nil
|
||||
}
|
||||
|
||||
ctx, err = asOwner(ctx, options.DB, options.OwnerID)
|
||||
ctx, err = asOwner(ctx, db, options.OwnerID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
template, err := options.DB.GetTemplateByID(ctx, templateID)
|
||||
template, err := db.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("template not found"), nil
|
||||
}
|
||||
|
||||
params, err := options.DB.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
|
||||
if template.OrganizationID != organizationID {
|
||||
return fantasy.NewTextErrorResponse("template not found"), nil
|
||||
}
|
||||
|
||||
params, err := db.GetTemplateVersionParameters(ctx, template.ActiveVersionID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("failed to get template parameters: %w", err).Error(),
|
||||
|
||||
@@ -1172,3 +1172,119 @@ func TestCreateChatNonDefaultOrg(t *testing.T) {
|
||||
}
|
||||
require.True(t, found, "chat should be visible in list")
|
||||
}
|
||||
|
||||
func TestListChats_OrgAdminOnlySeesOwnChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, firstUser := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: func() *codersdk.DeploymentValues {
|
||||
v := coderdtest.DeploymentValues(t)
|
||||
v.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
return v
|
||||
}(),
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureMultipleOrganizations: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Set up a chat provider and model config.
|
||||
provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
IsDefault: ptr.Ref(true),
|
||||
ContextLimit: ptr.Ref(int64(1000)),
|
||||
CompressionThreshold: ptr.Ref(int32(70)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a second (non-default) org.
|
||||
secondOrg := coderdenttest.CreateOrganization(t, client, coderdenttest.CreateOrganizationOptions{})
|
||||
|
||||
// Create a regular member with agents access in the second org.
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(
|
||||
t, client, firstUser.OrganizationID, rbac.RoleAgentsAccess(),
|
||||
)
|
||||
_, err = client.PostOrganizationMember(ctx, secondOrg.ID, member.Username)
|
||||
require.NoError(t, err)
|
||||
memberExp := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Member creates a chat in the second org.
|
||||
memberChat, err := memberExp.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: secondOrg.ID,
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello from member",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, secondOrg.ID, memberChat.OrganizationID)
|
||||
|
||||
// Create an org admin in the second org with agents access.
|
||||
adminClientRaw, _ := coderdtest.CreateAnotherUser(
|
||||
t, client, firstUser.OrganizationID,
|
||||
rbac.ScopedRoleOrgAdmin(secondOrg.ID), rbac.RoleAgentsAccess(),
|
||||
)
|
||||
adminExp := codersdk.NewExperimentalClient(adminClientRaw)
|
||||
|
||||
// Admin creates a chat in the second org.
|
||||
adminChat, err := adminExp.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: secondOrg.ID,
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello from admin",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, secondOrg.ID, adminChat.OrganizationID)
|
||||
|
||||
// Admin lists chats -- should only see their own chat.
|
||||
// TODO: The handler currently filters by OwnerID (the
|
||||
// authenticated user), so org admins cannot see other
|
||||
// users' chats even though RBAC would allow it. If the
|
||||
// handler gains an owner filter parameter, update this
|
||||
// test to verify cross-user visibility.
|
||||
adminChats, err := adminExp.ListChats(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundAdmin, foundMember bool
|
||||
for _, c := range adminChats {
|
||||
if c.ID == adminChat.ID {
|
||||
foundAdmin = true
|
||||
}
|
||||
if c.ID == memberChat.ID {
|
||||
foundMember = true
|
||||
}
|
||||
}
|
||||
require.True(t, foundAdmin, "admin should see own chat")
|
||||
require.False(t, foundMember, "admin should NOT see member chat (OwnerID filter)")
|
||||
|
||||
// Positive control: member can list their own chat.
|
||||
memberChats, err := memberExp.ListChats(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
var memberSeeOwn bool
|
||||
for _, c := range memberChats {
|
||||
if c.ID == memberChat.ID {
|
||||
memberSeeOwn = true
|
||||
}
|
||||
}
|
||||
require.True(t, memberSeeOwn, "member should see own chat")
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ export const Navbar: FC = () => {
|
||||
featureVisibility.connection_log && permissions.viewAnyConnectionLog;
|
||||
const canViewAIBridge =
|
||||
featureVisibility.aibridge && permissions.viewAnyAIBridgeInterception;
|
||||
const canCreateChat = permissions.createChat;
|
||||
|
||||
const uniqueLinks = new Map<string, LinkConfig>();
|
||||
for (const link of appearance.support_links ?? []) {
|
||||
@@ -47,6 +48,7 @@ export const Navbar: FC = () => {
|
||||
canViewAuditLog={canViewAuditLog}
|
||||
canViewConnectionLog={canViewConnectionLog}
|
||||
canViewAIBridge={canViewAIBridge}
|
||||
canCreateChat={canCreateChat}
|
||||
proxyContextValue={proxyContextValue}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -35,6 +35,7 @@ const meta: Meta<typeof NavbarView> = {
|
||||
canViewDeployment: true,
|
||||
canViewHealth: true,
|
||||
canViewOrganizations: true,
|
||||
canCreateChat: true,
|
||||
supportLinks: [],
|
||||
},
|
||||
decorators: [withDashboardProvider],
|
||||
@@ -91,6 +92,18 @@ export const ForMember: Story = {
|
||||
canViewDeployment: false,
|
||||
canViewHealth: false,
|
||||
canViewOrganizations: false,
|
||||
canCreateChat: false,
|
||||
},
|
||||
};
|
||||
|
||||
export const ForMemberWithAgentsAccess: Story = {
|
||||
args: {
|
||||
user: MockUserMember,
|
||||
canViewAuditLog: false,
|
||||
canViewDeployment: false,
|
||||
canViewHealth: false,
|
||||
canViewOrganizations: false,
|
||||
canCreateChat: true,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ interface NavbarViewProps {
|
||||
canViewConnectionLog: boolean;
|
||||
canViewHealth: boolean;
|
||||
canViewAIBridge: boolean;
|
||||
canCreateChat: boolean;
|
||||
proxyContextValue?: ProxyContextValue;
|
||||
}
|
||||
|
||||
@@ -57,6 +58,7 @@ export const NavbarView: FC<NavbarViewProps> = ({
|
||||
canViewAuditLog,
|
||||
canViewConnectionLog,
|
||||
canViewAIBridge,
|
||||
canCreateChat,
|
||||
proxyContextValue,
|
||||
}) => {
|
||||
const isDev = buildInfo ? isDevBuild(buildInfo) : false;
|
||||
@@ -78,7 +80,7 @@ export const NavbarView: FC<NavbarViewProps> = ({
|
||||
)}
|
||||
</NavLink>
|
||||
|
||||
<NavItems className="ml-4" user={user} />
|
||||
<NavItems className="ml-4" user={user} canCreateChat={canCreateChat} />
|
||||
|
||||
{isPreRelease && buildInfo?.version && (
|
||||
<a
|
||||
@@ -165,9 +167,10 @@ export const NavbarView: FC<NavbarViewProps> = ({
|
||||
interface NavItemsProps {
|
||||
className?: string;
|
||||
user: TypesGen.User;
|
||||
canCreateChat: boolean;
|
||||
}
|
||||
|
||||
const NavItems: FC<NavItemsProps> = ({ className, user }) => {
|
||||
const NavItems: FC<NavItemsProps> = ({ className, user, canCreateChat }) => {
|
||||
const location = useLocation();
|
||||
|
||||
return (
|
||||
@@ -192,7 +195,7 @@ const NavItems: FC<NavItemsProps> = ({ className, user }) => {
|
||||
Templates
|
||||
</NavLink>
|
||||
<TasksNavItem user={user} />
|
||||
<AgentsNavItem />
|
||||
<AgentsNavItem canCreateChat={canCreateChat} />
|
||||
</nav>
|
||||
);
|
||||
};
|
||||
@@ -257,11 +260,12 @@ function idleTasksLabel(count: number) {
|
||||
return `You have ${count} ${count === 1 ? "task" : "tasks"} waiting for input`;
|
||||
}
|
||||
|
||||
const AgentsNavItem: FC = () => {
|
||||
const AgentsNavItem: FC<{ canCreateChat: boolean }> = ({ canCreateChat }) => {
|
||||
const { experiments, buildInfo } = useDashboard();
|
||||
const canSeeAgents = experiments.includes("agents") || isDevBuild(buildInfo);
|
||||
const experimentEnabled =
|
||||
experiments.includes("agents") || isDevBuild(buildInfo);
|
||||
|
||||
if (!canSeeAgents) {
|
||||
if (!experimentEnabled || !canCreateChat) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ describe("useEmptyStateDraft", () => {
|
||||
});
|
||||
expect(localStorage.getItem(emptyInputStorageKey)).toBeNull();
|
||||
|
||||
// Simulate error recovery — re-enable persistence.
|
||||
// Simulate error recovery -- re-enable persistence.
|
||||
act(() => {
|
||||
result.current.resetDraft();
|
||||
});
|
||||
@@ -307,12 +307,14 @@ describe("useFileAttachments persistence", () => {
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
lastModified: number;
|
||||
organizationId: string;
|
||||
}> = {},
|
||||
) => ({
|
||||
fileId: "file-1",
|
||||
fileName: "photo.png",
|
||||
fileType: "image/png",
|
||||
lastModified: 1000,
|
||||
organizationId: "org-1",
|
||||
...overrides,
|
||||
});
|
||||
|
||||
@@ -504,4 +506,102 @@ describe("useFileAttachments persistence", () => {
|
||||
expect(localStorage.getItem(persistedAttachmentsStorageKey)).toBeNull();
|
||||
unmount();
|
||||
});
|
||||
|
||||
it("restores only attachments matching current organization", () => {
|
||||
const entries = [
|
||||
makePersistedEntry({
|
||||
fileId: "f1",
|
||||
fileName: "a.png",
|
||||
organizationId: "org-1",
|
||||
}),
|
||||
makePersistedEntry({
|
||||
fileId: "f2",
|
||||
fileName: "b.png",
|
||||
organizationId: "org-2",
|
||||
}),
|
||||
];
|
||||
localStorage.setItem(
|
||||
persistedAttachmentsStorageKey,
|
||||
JSON.stringify(entries),
|
||||
);
|
||||
|
||||
const { result, unmount } = renderFileAttachments();
|
||||
|
||||
expect(result.current.attachments).toHaveLength(1);
|
||||
expect(result.current.attachments[0].name).toBe("a.png");
|
||||
|
||||
// localStorage should be pruned to only the matching org.
|
||||
const stored = JSON.parse(
|
||||
localStorage.getItem(persistedAttachmentsStorageKey)!,
|
||||
);
|
||||
expect(stored).toHaveLength(1);
|
||||
expect(stored[0].fileId).toBe("f1");
|
||||
unmount();
|
||||
});
|
||||
|
||||
it("prunes legacy entries without organizationId", () => {
|
||||
const legacy = {
|
||||
fileId: "old-file",
|
||||
fileName: "legacy.png",
|
||||
fileType: "image/png",
|
||||
lastModified: 1000,
|
||||
// No organizationId field -- simulates pre-org-scoping data.
|
||||
};
|
||||
localStorage.setItem(
|
||||
persistedAttachmentsStorageKey,
|
||||
JSON.stringify([legacy]),
|
||||
);
|
||||
|
||||
const { result, unmount } = renderFileAttachments();
|
||||
|
||||
expect(result.current.attachments).toHaveLength(0);
|
||||
expect(localStorage.getItem(persistedAttachmentsStorageKey)).toBeNull();
|
||||
unmount();
|
||||
});
|
||||
|
||||
it("skips restoration when organizationId is undefined", () => {
|
||||
const entry = makePersistedEntry();
|
||||
localStorage.setItem(
|
||||
persistedAttachmentsStorageKey,
|
||||
JSON.stringify([entry]),
|
||||
);
|
||||
|
||||
const { result, unmount } = renderHook(() =>
|
||||
useFileAttachments(undefined, { persist: true }),
|
||||
);
|
||||
|
||||
// Should not restore -- org not yet known.
|
||||
expect(result.current.attachments).toHaveLength(0);
|
||||
// Should NOT prune -- org unknown, so leave storage alone.
|
||||
expect(localStorage.getItem(persistedAttachmentsStorageKey)).not.toBeNull();
|
||||
unmount();
|
||||
});
|
||||
|
||||
it("persists organizationId with attachment metadata", async () => {
|
||||
const { API } = await import("#/api/api");
|
||||
vi.spyOn(API.experimental, "uploadChatFile").mockResolvedValue({
|
||||
id: "new-file-id",
|
||||
});
|
||||
vi.spyOn(globalThis, "fetch").mockResolvedValue(new Response());
|
||||
|
||||
const { result, unmount } = renderFileAttachments();
|
||||
|
||||
const file = new File(["hello"], "test.png", { type: "image/png" });
|
||||
|
||||
act(() => {
|
||||
result.current.handleAttach([file]);
|
||||
});
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const state = result.current.uploadStates.get(file);
|
||||
expect(state?.status).toBe("uploaded");
|
||||
});
|
||||
|
||||
const stored = JSON.parse(
|
||||
localStorage.getItem(persistedAttachmentsStorageKey)!,
|
||||
);
|
||||
expect(stored).toHaveLength(1);
|
||||
expect(stored[0].organizationId).toBe("org-1");
|
||||
unmount();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,6 +9,13 @@ import {
|
||||
import { withDashboardProvider } from "#/testHelpers/storybook";
|
||||
import { AgentCreateForm } from "./AgentCreateForm";
|
||||
|
||||
// Query key used by permittedOrganizations() in the form.
|
||||
const permittedOrgsKey = [
|
||||
"organizations",
|
||||
"permitted",
|
||||
{ object: { resource_type: "chat" }, action: "create" },
|
||||
];
|
||||
|
||||
const modelConfigID = "model-config-1";
|
||||
|
||||
const modelOptions = [
|
||||
@@ -231,6 +238,7 @@ export const PreservesAttachmentsOnFailedSend: Story = {
|
||||
fileName: "photo.png",
|
||||
fileType: "image/png",
|
||||
lastModified: 1000,
|
||||
organizationId: "my-organization-id",
|
||||
},
|
||||
]),
|
||||
);
|
||||
@@ -319,6 +327,29 @@ export const WithOrganizationPicker: Story = {
|
||||
parameters: {
|
||||
showOrganizations: true,
|
||||
organizations: [MockDefaultOrganization, MockOrganization2],
|
||||
queries: [
|
||||
{
|
||||
key: permittedOrgsKey,
|
||||
data: [MockDefaultOrganization, MockOrganization2],
|
||||
},
|
||||
],
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
// Verify the org picker rendered (component didn't crash).
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
canvas.getByTestId("organization-autocomplete"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
// Type into the chat input to trigger re-renders. If the
|
||||
// permittedOrgs fallback is referentially unstable, this
|
||||
// causes a render cascade that hits React's update limit.
|
||||
const input = canvas.getByTestId("chat-message-input");
|
||||
await userEvent.click(input);
|
||||
await userEvent.keyboard("hello world");
|
||||
// The org picker should still be present after typing.
|
||||
expect(canvas.getByTestId("organization-autocomplete")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { type FC, useEffect, useRef, useState } from "react";
|
||||
import { type FC, useEffect, useEffectEvent, useRef, useState } from "react";
|
||||
import { useQuery } from "react-query";
|
||||
import { Link } from "react-router";
|
||||
import { toast } from "sonner";
|
||||
import { isApiError } from "#/api/errors";
|
||||
import { permittedOrganizations } from "#/api/queries/organizations";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { Alert, AlertDescription } from "#/components/Alert/Alert";
|
||||
import { ErrorAlert } from "#/components/Alert/ErrorAlert";
|
||||
@@ -193,9 +195,9 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
const stored = localStorage.getItem(selectedWorkspaceIdStorageKey);
|
||||
if (!stored) return null;
|
||||
|
||||
// If workspaces haven't loaded yet, keep the stored value.
|
||||
// It will be re-validated once the list arrives via
|
||||
// filteredWorkspaces clearing the selection if stale.
|
||||
// The stored value is kept optimistically until workspaces
|
||||
// load. effectiveWorkspaceId (computed after render) drops
|
||||
// it if it doesn't match the current org's workspaces.
|
||||
if (workspaceOptions.length === 0) return stored;
|
||||
|
||||
// Validate the stored workspace still exists and belongs
|
||||
@@ -257,12 +259,6 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
lastUsedModelID,
|
||||
]);
|
||||
|
||||
// Keep a mutable ref to selectedWorkspaceId and selectedModel so
|
||||
// that the onSend callback always sees the latest values without
|
||||
// the shared input component re-rendering on every change.
|
||||
const selectedWorkspaceIdRef = useRef(selectedWorkspaceId);
|
||||
const selectedModelRef = useRef(selectedModel);
|
||||
const organizationIdRef = useRef(organizationId);
|
||||
const [userMCPServerIds, setUserMCPServerIds] = useState<string[] | null>(
|
||||
null,
|
||||
);
|
||||
@@ -276,13 +272,6 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
}
|
||||
return getDefaultMCPSelection(mcpServers ?? []);
|
||||
})();
|
||||
const selectedMCPServerIdsRef = useRef(effectiveMCPServerIds);
|
||||
useEffect(() => {
|
||||
selectedWorkspaceIdRef.current = selectedWorkspaceId;
|
||||
selectedModelRef.current = selectedModel;
|
||||
selectedMCPServerIdsRef.current = effectiveMCPServerIds;
|
||||
organizationIdRef.current = organizationId;
|
||||
});
|
||||
const handleWorkspaceChange = (value: string | null) => {
|
||||
if (value === null) {
|
||||
setSelectedWorkspaceId(null);
|
||||
@@ -298,27 +287,47 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
setUserSelectedModel(value);
|
||||
};
|
||||
|
||||
const handleSend = async (message: string, fileIDs?: string[]) => {
|
||||
submitDraft();
|
||||
await onCreateChat({
|
||||
message,
|
||||
fileIDs,
|
||||
workspaceId: selectedWorkspaceIdRef.current ?? undefined,
|
||||
model: selectedModelRef.current || undefined,
|
||||
organizationId: organizationIdRef.current,
|
||||
mcpServerIds:
|
||||
selectedMCPServerIdsRef.current.length > 0
|
||||
? [...selectedMCPServerIdsRef.current]
|
||||
: undefined,
|
||||
}).catch((err) => {
|
||||
// Re-enable draft persistence so the user can edit
|
||||
// and retry after a failed send attempt, then rethrow
|
||||
// so callers (handleSendWithAttachments) can preserve
|
||||
// attachments on failure.
|
||||
resetDraft();
|
||||
throw err;
|
||||
});
|
||||
};
|
||||
const isForbidden = !canCreateChat;
|
||||
|
||||
// Filter workspaces by the selected organization. We use
|
||||
// client-side filtering of the full "owner:me" fetch rather
|
||||
// than re-querying with an org filter because it avoids
|
||||
// extra loading/error states on org change. The full list is
|
||||
// already small (user's own workspaces) and limit: 0
|
||||
// guarantees completeness. If workspace counts grow large
|
||||
// enough to warrant pagination, this should switch to a
|
||||
// server-side organization:<name> query filter.
|
||||
const filteredWorkspaces =
|
||||
showOrganizations && selectedOrg
|
||||
? workspaceOptions.filter((ws) => ws.organization_id === selectedOrg.id)
|
||||
: workspaceOptions;
|
||||
|
||||
const effectiveWorkspaceId =
|
||||
selectedWorkspaceId !== null &&
|
||||
(isWorkspacesLoading ||
|
||||
filteredWorkspaces.some((ws) => ws.id === selectedWorkspaceId))
|
||||
? selectedWorkspaceId
|
||||
: null;
|
||||
|
||||
const handleSend = useEffectEvent(
|
||||
async (message: string, fileIDs?: string[]) => {
|
||||
submitDraft();
|
||||
await onCreateChat({
|
||||
message,
|
||||
fileIDs,
|
||||
workspaceId: effectiveWorkspaceId ?? undefined,
|
||||
model: selectedModel || undefined,
|
||||
organizationId,
|
||||
mcpServerIds:
|
||||
effectiveMCPServerIds.length > 0
|
||||
? [...effectiveMCPServerIds]
|
||||
: undefined,
|
||||
}).catch((err) => {
|
||||
resetDraft();
|
||||
throw err;
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
const {
|
||||
attachments,
|
||||
@@ -357,20 +366,42 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
const isForbidden = !canCreateChat;
|
||||
const permittedOrgsQuery = useQuery({
|
||||
...permittedOrganizations({
|
||||
object: { resource_type: "chat" },
|
||||
action: "create",
|
||||
}),
|
||||
enabled: showOrganizations,
|
||||
});
|
||||
const permittedOrgs = permittedOrgsQuery.data ?? organizations;
|
||||
|
||||
// Filter workspaces by the selected organization. We use
|
||||
// client-side filtering of the full "owner:me" fetch rather
|
||||
// than re-querying with an org filter because it avoids
|
||||
// extra loading/error states on org change. The full list is
|
||||
// already small (user's own workspaces) and limit: 0
|
||||
// guarantees completeness. If workspace counts grow large
|
||||
// enough to warrant pagination, this should switch to a
|
||||
// server-side organization:<name> query filter.
|
||||
const filteredWorkspaces =
|
||||
showOrganizations && selectedOrg
|
||||
? workspaceOptions.filter((ws) => ws.organization_id === selectedOrg.id)
|
||||
: workspaceOptions;
|
||||
// Reconcile selectedOrg when permission filtering removes it.
|
||||
// Only pure state setters run during render; side effects
|
||||
// (localStorage, blob URL cleanup) run in the effect below.
|
||||
const [prevPermittedOrgs, setPrevPermittedOrgs] = useState(permittedOrgs);
|
||||
const [orgWasAdjusted, setOrgWasAdjusted] = useState(false);
|
||||
if (permittedOrgs !== prevPermittedOrgs) {
|
||||
setPrevPermittedOrgs(permittedOrgs);
|
||||
if (selectedOrg && !permittedOrgs.some((o) => o.id === selectedOrg.id)) {
|
||||
setSelectedOrg(permittedOrgs[0] ?? null);
|
||||
setOrgWasAdjusted(true);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up workspace and attachment state after a programmatic
|
||||
// org change from permission filtering. These calls have side
|
||||
// effects (localStorage, blob URL revocation) that must not
|
||||
// run during render.
|
||||
const onOrgAdjusted = useEffectEvent(() => {
|
||||
handleWorkspaceChange(null);
|
||||
resetAttachments();
|
||||
});
|
||||
useEffect(() => {
|
||||
if (orgWasAdjusted) {
|
||||
setOrgWasAdjusted(false);
|
||||
onOrgAdjusted();
|
||||
}
|
||||
}, [orgWasAdjusted]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -399,28 +430,33 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
)
|
||||
) : null}
|
||||
{workspacesError != null && <ErrorAlert error={workspacesError} />}
|
||||
{showOrganizations && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label htmlFor="organization">Organization</Label>
|
||||
<OrganizationAutocomplete
|
||||
id="organization"
|
||||
required
|
||||
value={selectedOrg}
|
||||
options={organizations}
|
||||
onChange={(newOrg) => {
|
||||
const orgChanged = newOrg?.id !== selectedOrg?.id;
|
||||
if (orgChanged && attachments.length > 0) {
|
||||
setPendingOrgChange(newOrg);
|
||||
return;
|
||||
}
|
||||
if (orgChanged) {
|
||||
handleWorkspaceChange(null);
|
||||
}
|
||||
setSelectedOrg(newOrg);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{permittedOrgsQuery.error != null && (
|
||||
<ErrorAlert error={permittedOrgsQuery.error} />
|
||||
)}
|
||||
{showOrganizations &&
|
||||
!permittedOrgsQuery.isLoading &&
|
||||
permittedOrgs.length > 1 && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Label htmlFor="organization">Organization</Label>
|
||||
<OrganizationAutocomplete
|
||||
id="organization"
|
||||
required
|
||||
value={selectedOrg}
|
||||
options={permittedOrgs}
|
||||
onChange={(newOrg) => {
|
||||
const orgChanged = newOrg?.id !== selectedOrg?.id;
|
||||
if (orgChanged && attachments.length > 0) {
|
||||
setPendingOrgChange(newOrg);
|
||||
return;
|
||||
}
|
||||
if (orgChanged) {
|
||||
handleWorkspaceChange(null);
|
||||
}
|
||||
setSelectedOrg(newOrg);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<AgentChatInput
|
||||
onSend={handleSendWithAttachments}
|
||||
placeholder="Ask Coder to build, fix bugs, or explore your project..."
|
||||
@@ -449,7 +485,7 @@ export const AgentCreateForm: FC<AgentCreateFormProps> = ({
|
||||
}}
|
||||
onMCPAuthComplete={onMCPAuthComplete}
|
||||
workspaceOptions={filteredWorkspaces}
|
||||
selectedWorkspaceId={selectedWorkspaceId}
|
||||
selectedWorkspaceId={effectiveWorkspaceId}
|
||||
onWorkspaceChange={handleWorkspaceChange}
|
||||
isWorkspaceLoading={isWorkspacesLoading}
|
||||
/>
|
||||
|
||||
@@ -2,7 +2,7 @@ import {
|
||||
type Dispatch,
|
||||
type SetStateAction,
|
||||
useEffect,
|
||||
useRef,
|
||||
useEffectEvent,
|
||||
useState,
|
||||
} from "react";
|
||||
import { API } from "#/api/api";
|
||||
@@ -21,18 +21,33 @@ interface PersistedAttachment {
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
lastModified: number;
|
||||
organizationId: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Restore previously persisted attachments from localStorage.
|
||||
* Creates synthetic File objects (empty blobs with correct metadata)
|
||||
* and populates the corresponding Maps so the UI can render them.
|
||||
*
|
||||
* Only attachments matching `currentOrgId` are returned. Entries
|
||||
* belonging to a different organization are pruned from storage.
|
||||
*/
|
||||
function restorePersistedAttachments(): {
|
||||
function restorePersistedAttachments(currentOrgId: string): {
|
||||
attachments: File[];
|
||||
uploadStates: Map<File, UploadState>;
|
||||
previewUrls: Map<File, string>;
|
||||
} {
|
||||
// When the org ID is not yet known (e.g. still loading), skip
|
||||
// restoration entirely so we don't accidentally prune valid
|
||||
// entries. The initializer only runs once, so the caller must
|
||||
// ensure the org ID is available before mounting the hook.
|
||||
if (!currentOrgId) {
|
||||
return {
|
||||
attachments: [],
|
||||
uploadStates: new Map(),
|
||||
previewUrls: new Map(),
|
||||
};
|
||||
}
|
||||
const stored = localStorage.getItem(persistedAttachmentsStorageKey);
|
||||
if (!stored) {
|
||||
return {
|
||||
@@ -43,11 +58,25 @@ function restorePersistedAttachments(): {
|
||||
}
|
||||
try {
|
||||
const persisted: PersistedAttachment[] = JSON.parse(stored);
|
||||
const matched = persisted.filter((p) => p.organizationId === currentOrgId);
|
||||
|
||||
// Prune entries that don't match the current org.
|
||||
if (matched.length !== persisted.length) {
|
||||
if (matched.length > 0) {
|
||||
localStorage.setItem(
|
||||
persistedAttachmentsStorageKey,
|
||||
JSON.stringify(matched),
|
||||
);
|
||||
} else {
|
||||
localStorage.removeItem(persistedAttachmentsStorageKey);
|
||||
}
|
||||
}
|
||||
|
||||
const attachments: File[] = [];
|
||||
const uploadStates = new Map<File, UploadState>();
|
||||
const previewUrls = new Map<File, string>();
|
||||
|
||||
for (const p of persisted) {
|
||||
for (const p of matched) {
|
||||
if (!p.fileId || !p.fileName) continue;
|
||||
// Synthetic File used as a Map key only. Its content is
|
||||
// never read because the existing file_id is reused at
|
||||
@@ -72,7 +101,11 @@ function restorePersistedAttachments(): {
|
||||
}
|
||||
}
|
||||
|
||||
function addPersistedAttachment(file: File, fileId: string) {
|
||||
function addPersistedAttachment(
|
||||
file: File,
|
||||
fileId: string,
|
||||
organizationId: string,
|
||||
) {
|
||||
const stored = localStorage.getItem(persistedAttachmentsStorageKey);
|
||||
let persisted: PersistedAttachment[];
|
||||
try {
|
||||
@@ -85,6 +118,7 @@ function addPersistedAttachment(file: File, fileId: string) {
|
||||
fileName: file.name,
|
||||
fileType: file.type,
|
||||
lastModified: file.lastModified,
|
||||
organizationId,
|
||||
});
|
||||
localStorage.setItem(
|
||||
persistedAttachmentsStorageKey,
|
||||
@@ -141,7 +175,7 @@ export function useFileAttachments(
|
||||
// when persistence is enabled. Computed once on first render.
|
||||
const [restored] = useState(() =>
|
||||
persist
|
||||
? restorePersistedAttachments()
|
||||
? restorePersistedAttachments(organizationId ?? "")
|
||||
: {
|
||||
attachments: [] as File[],
|
||||
uploadStates: new Map<File, UploadState>(),
|
||||
@@ -157,16 +191,13 @@ export function useFileAttachments(
|
||||
);
|
||||
|
||||
// Revoke blob URLs on unmount to prevent memory leaks.
|
||||
const previewUrlsRef = useRef(previewUrls);
|
||||
useEffect(() => {
|
||||
previewUrlsRef.current = previewUrls;
|
||||
const revokePreviewUrls = useEffectEvent(() => {
|
||||
for (const [, url] of previewUrls) {
|
||||
if (url.startsWith("blob:")) URL.revokeObjectURL(url);
|
||||
}
|
||||
});
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
for (const [, url] of previewUrlsRef.current) {
|
||||
if (url.startsWith("blob:")) URL.revokeObjectURL(url);
|
||||
}
|
||||
};
|
||||
return () => revokePreviewUrls();
|
||||
}, []);
|
||||
|
||||
const startUpload = (file: File) => {
|
||||
@@ -179,6 +210,10 @@ export function useFileAttachments(
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const shouldPersist = persist && Boolean(organizationId);
|
||||
const isImage = file.type.startsWith("image/");
|
||||
|
||||
setUploadStates((prev) => new Map(prev).set(file, { status: "uploading" }));
|
||||
void (async () => {
|
||||
try {
|
||||
@@ -192,14 +227,14 @@ export function useFileAttachments(
|
||||
fileId: result.id,
|
||||
}),
|
||||
);
|
||||
if (persist) {
|
||||
addPersistedAttachment(file, result.id);
|
||||
if (shouldPersist) {
|
||||
addPersistedAttachment(file, result.id, organizationId!);
|
||||
}
|
||||
// Pre-warm the browser HTTP cache for images so the
|
||||
// timeline can render them instantly after send. We
|
||||
// intentionally skip text attachments because the
|
||||
// composer already has the text content locally.
|
||||
if (file.type.startsWith("image/")) {
|
||||
if (isImage) {
|
||||
void fetch(`/api/experimental/chats/files/${result.id}`);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
@@ -310,9 +345,7 @@ export function useFileAttachments(
|
||||
};
|
||||
|
||||
const resetAttachments = () => {
|
||||
for (const [, url] of previewUrlsRef.current) {
|
||||
if (url.startsWith("blob:")) URL.revokeObjectURL(url);
|
||||
}
|
||||
revokePreviewUrls();
|
||||
setPreviewUrls(new Map());
|
||||
setTextContents(new Map());
|
||||
setUploadStates(new Map());
|
||||
|
||||
Reference in New Issue
Block a user