fix: exclude subagent chats from sidebar pagination (#24404)

GetChats now returns only root chats (parent_chat_id IS NULL).
A new GetChildChatsByParentIDs query fetches children for visible
roots and embeds them in each parent's Children field. The
singular getChat endpoint does the same.

Archive invariant is one-way: parent archived implies child
archived. Parent archive/unarchive cascades via root_chat_id.
Individual child archive is permitted; child unarchive while the
parent is archived is rejected atomically (row lock on child,
re-read parent inside the transaction). Embedded children are
filtered by the caller's archive state so individually-archived
children stay hidden from active-parent views.

Gitsync MarkStale uses GetChatsByWorkspaceIDs directly;
MarkStaleParams.OwnerID removed (dead after the switch).

Frontend: buildChatTree reads from the embedded children field,
WebSocket handlers route child events into the parent's children
array, and archiving a child strips it from the parent cache.
This commit is contained in:
Mathias Fredriksson
2026-04-20 13:19:59 +03:00
committed by GitHub
parent df429b7f60
commit fc2493780f
30 changed files with 1514 additions and 225 deletions
+53 -9
View File
@@ -1609,6 +1609,11 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database
parentChatID := c.ParentChatID.UUID
chat.ParentChatID = &parentChatID
}
// Always initialize Children to an empty slice so the JSON
// field serializes as [] rather than null. Root chats may
// later have children populated; child chats remain empty
// because nesting depth is capped at 1.
chat.Children = []codersdk.Chat{}
switch {
case c.RootChatID.Valid:
rootChatID := c.RootChatID.UUID
@@ -1756,19 +1761,21 @@ func ChatDebugStep(s database.ChatDebugStep) codersdk.ChatDebugStep {
}
}
// ChatRows converts a slice of database.GetChatsRow (which embeds
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
// from the provided map. When diffStatusesByChatID is non-nil,
// chats without an entry receive an empty DiffStatus.
func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]database.ChatDiffStatus) []codersdk.Chat {
result := make([]codersdk.Chat, len(rows))
for i, row := range rows {
diffStatus, ok := diffStatusesByChatID[row.Chat.ID]
// ChildChatRows converts child chat rows to codersdk.Chat values,
// resolving diff statuses from the shared map. When diffStatuses
// is non-nil, children without an entry receive an empty DiffStatus.
func ChildChatRows(
children []database.GetChildChatsByParentIDsRow,
diffStatuses map[uuid.UUID]database.ChatDiffStatus,
) []codersdk.Chat {
result := make([]codersdk.Chat, len(children))
for i, row := range children {
diffStatus, ok := diffStatuses[row.Chat.ID]
if ok {
result[i] = Chat(row.Chat, &diffStatus, nil)
} else {
result[i] = Chat(row.Chat, nil, nil)
if diffStatusesByChatID != nil {
if diffStatuses != nil {
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
result[i].DiffStatus = &emptyDiffStatus
}
@@ -1778,6 +1785,43 @@ func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]da
return result
}
// ChatRowsWithChildren converts root chat rows and their child rows
// into codersdk.Chat values with children embedded under each parent.
// Both root and child diff statuses are resolved from the shared map.
func ChatRowsWithChildren(
roots []database.GetChatsRow,
children []database.GetChildChatsByParentIDsRow,
diffStatuses map[uuid.UUID]database.ChatDiffStatus,
) []codersdk.Chat {
// Group children by parent ID.
childrenByParent := make(map[uuid.UUID][]database.GetChildChatsByParentIDsRow, len(children))
for _, row := range children {
parentID := row.Chat.ParentChatID.UUID
childrenByParent[parentID] = append(childrenByParent[parentID], row)
}
result := make([]codersdk.Chat, len(roots))
for i, row := range roots {
diffStatus, ok := diffStatuses[row.Chat.ID]
if ok {
result[i] = Chat(row.Chat, &diffStatus, nil)
} else {
result[i] = Chat(row.Chat, nil, nil)
if diffStatuses != nil {
emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil)
result[i].DiffStatus = &emptyDiffStatus
}
}
result[i].HasUnread = row.HasUnread
// Embed child chats.
if childRows, ok := childrenByParent[row.Chat.ID]; ok {
result[i].Children = ChildChatRows(childRows, diffStatuses)
}
}
return result
}
// ChatDiffStatus converts a database.ChatDiffStatus to a
// codersdk.ChatDiffStatus. When status is nil an empty value
// containing only the chatID is returned.
+1 -1
View File
@@ -856,7 +856,7 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
v := reflect.ValueOf(got)
typ := v.Type()
// HasUnread is populated by ChatRows (which joins the
// HasUnread is populated by ChatRowsWithChildren (which joins the
// read-cursor query), not by Chat. Warnings is a transient
// field populated by handlers, not the converter. Both are
// expected to remain zero here.
+8
View File
@@ -2944,6 +2944,14 @@ func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Ti
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
}
func (q *querier) GetChildChatsByParentIDs(ctx context.Context, arg database.GetChildChatsByParentIDsParams) ([]database.GetChildChatsByParentIDsRow, error) {
// Each child is independently authorized via post-filter.
// The handler calls this after GetChats already authorized
// the parent chats, but we still verify read access on
// every child row for defense in depth.
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChildChatsByParentIDs)(ctx, arg)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
+21
View File
@@ -820,6 +820,27 @@ func (s *MethodTestSuite) TestChats() {
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("GetChildChatsByParentIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
parentA := testutil.Fake(s.T(), faker, database.Chat{})
parentB := testutil.Fake(s.T(), faker, database.Chat{})
childA := testutil.Fake(s.T(), faker, database.Chat{
ParentChatID: uuid.NullUUID{UUID: parentA.ID, Valid: true},
})
childB := testutil.Fake(s.T(), faker, database.Chat{
ParentChatID: uuid.NullUUID{UUID: parentB.ID, Valid: true},
})
parentIDs := []uuid.UUID{parentA.ID, parentB.ID}
params := database.GetChildChatsByParentIDsParams{
ParentIds: parentIDs,
Archived: sql.NullBool{Bool: false, Valid: true},
}
rows := []database.GetChildChatsByParentIDsRow{
{Chat: childA},
{Chat: childB},
}
dbm.EXPECT().GetChildChatsByParentIDs(gomock.Any(), params).Return(rows, nil).AnyTimes()
check.Args(params).Asserts(childA, policy.ActionRead, childB, policy.ActionRead).Returns(rows)
}))
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes()
@@ -1456,6 +1456,14 @@ func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfte
return r0, r1
}
func (m queryMetricsStore) GetChildChatsByParentIDs(ctx context.Context, arg database.GetChildChatsByParentIDsParams) ([]database.GetChildChatsByParentIDsRow, error) {
start := time.Now()
r0, r1 := m.s.GetChildChatsByParentIDs(ctx, arg)
m.queryLatencies.WithLabelValues("GetChildChatsByParentIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChildChatsByParentIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
+15
View File
@@ -2687,6 +2687,21 @@ func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
}
// GetChildChatsByParentIDs mocks base method.
func (m *MockStore) GetChildChatsByParentIDs(ctx context.Context, arg database.GetChildChatsByParentIDsParams) ([]database.GetChildChatsByParentIDsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChildChatsByParentIDs", ctx, arg)
ret0, _ := ret[0].([]database.GetChildChatsByParentIDsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChildChatsByParentIDs indicates an expected call of GetChildChatsByParentIDs.
func (mr *MockStoreMockRecorder) GetChildChatsByParentIDs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChildChatsByParentIDs", reflect.TypeOf((*MockStore)(nil).GetChildChatsByParentIDs), ctx, arg)
}
// GetConnectionLogsOffset mocks base method.
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
+4
View File
@@ -182,6 +182,10 @@ func (r GetChatsRow) RBACObject() rbac.Object {
return r.Chat.RBACObject()
}
func (r GetChildChatsByParentIDsRow) RBACObject() rbac.Object {
return r.Chat.RBACObject()
}
func (c ChatFile) RBACObject() rbac.Object {
return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID)
}
+5
View File
@@ -349,6 +349,11 @@ type sqlcQuerier interface {
// snapshot collection. Uses updated_at so that long-running chats
// still appear in each snapshot window while they are active.
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
// Fetches child chats of the given parents, optionally filtered by
// archive state (NULL = all, true/false = match). The archive
// invariant (parent archived implies child archived) is enforced
// at write time, not here.
GetChildChatsByParentIDs(ctx context.Context, arg GetChildChatsByParentIDsParams) ([]GetChildChatsByParentIDsRow, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
+94
View File
@@ -6620,6 +6620,11 @@ WHERE
WHEN $4::jsonb IS NOT NULL THEN chats.labels @> $4::jsonb
ELSE true
END
-- Paginate over root chats only. Children are fetched
-- separately via GetChildChatsByParentIDs and embedded under
-- each parent. Other callers that need the full set should
-- use a narrower query (e.g. GetChatsByWorkspaceIDs).
AND chats.parent_chat_id IS NULL
-- Authorize Filter clause will be injected below in GetAuthorizedChats
-- @authorize_filter
ORDER BY
@@ -6838,6 +6843,95 @@ func (q *sqlQuerier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time
return items, nil
}
const getChildChatsByParentIDs = `-- name: GetChildChatsByParentIDs :many
SELECT
chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools, chats.organization_id, chats.plan_mode, chats.client_type,
EXISTS (
SELECT 1 FROM chat_messages cm
WHERE cm.chat_id = chats.id
AND cm.role = 'assistant'
AND cm.deleted = false
AND cm.id > COALESCE(chats.last_read_message_id, 0)
) AS has_unread
FROM
chats
WHERE
chats.parent_chat_id = ANY($1 :: uuid[])
AND CASE
WHEN $2 :: boolean IS NULL THEN true
ELSE chats.archived = $2 :: boolean
END
ORDER BY
chats.created_at ASC,
chats.id ASC
`
type GetChildChatsByParentIDsParams struct {
ParentIds []uuid.UUID `db:"parent_ids" json:"parent_ids"`
Archived sql.NullBool `db:"archived" json:"archived"`
}
type GetChildChatsByParentIDsRow struct {
Chat Chat `db:"chat" json:"chat"`
HasUnread bool `db:"has_unread" json:"has_unread"`
}
// Fetches child chats of the given parents, optionally filtered by
// archive state (NULL = all, true/false = match). The archive
// invariant (parent archived implies child archived) is enforced
// at write time, not here.
func (q *sqlQuerier) GetChildChatsByParentIDs(ctx context.Context, arg GetChildChatsByParentIDsParams) ([]GetChildChatsByParentIDsRow, error) {
rows, err := q.db.QueryContext(ctx, getChildChatsByParentIDs, pq.Array(arg.ParentIds), arg.Archived)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChildChatsByParentIDsRow
for rows.Next() {
var i GetChildChatsByParentIDsRow
if err := rows.Scan(
&i.Chat.ID,
&i.Chat.OwnerID,
&i.Chat.WorkspaceID,
&i.Chat.Title,
&i.Chat.Status,
&i.Chat.WorkerID,
&i.Chat.StartedAt,
&i.Chat.HeartbeatAt,
&i.Chat.CreatedAt,
&i.Chat.UpdatedAt,
&i.Chat.ParentChatID,
&i.Chat.RootChatID,
&i.Chat.LastModelConfigID,
&i.Chat.Archived,
&i.Chat.LastError,
&i.Chat.Mode,
pq.Array(&i.Chat.MCPServerIDs),
&i.Chat.Labels,
&i.Chat.BuildID,
&i.Chat.AgentID,
&i.Chat.PinOrder,
&i.Chat.LastReadMessageID,
&i.Chat.LastInjectedContext,
&i.Chat.DynamicTools,
&i.Chat.OrganizationID,
&i.Chat.PlanMode,
&i.Chat.ClientType,
&i.HasUnread,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
+31
View File
@@ -373,6 +373,11 @@ WHERE
WHEN sqlc.narg('label_filter')::jsonb IS NOT NULL THEN chats.labels @> sqlc.narg('label_filter')::jsonb
ELSE true
END
-- Paginate over root chats only. Children are fetched
-- separately via GetChildChatsByParentIDs and embedded under
-- each parent. Other callers that need the full set should
-- use a narrower query (e.g. GetChatsByWorkspaceIDs).
AND chats.parent_chat_id IS NULL
-- Authorize Filter clause will be injected below in GetAuthorizedChats
-- @authorize_filter
ORDER BY
@@ -390,6 +395,32 @@ LIMIT
-- Default to 50 to prevent accidental excessively large queries.
COALESCE(NULLIF(@limit_opt :: int, 0), 50);
-- name: GetChildChatsByParentIDs :many
-- Fetches child chats of the given parents, optionally filtered by
-- archive state (NULL = all, true/false = match). The archive
-- invariant (parent archived implies child archived) is enforced
-- at write time, not here.
SELECT
sqlc.embed(chats),
EXISTS (
SELECT 1 FROM chat_messages cm
WHERE cm.chat_id = chats.id
AND cm.role = 'assistant'
AND cm.deleted = false
AND cm.id > COALESCE(chats.last_read_message_id, 0)
) AS has_unread
FROM
chats
WHERE
chats.parent_chat_id = ANY(@parent_ids :: uuid[])
AND CASE
WHEN sqlc.narg('archived') :: boolean IS NULL THEN true
ELSE chats.archived = sqlc.narg('archived') :: boolean
END
ORDER BY
chats.created_at ASC,
chats.id ASC;
-- name: InsertChat :one
INSERT INTO chats (
organization_id,
+114 -6
View File
@@ -336,13 +336,39 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
return
}
// Extract the Chat objects for diff status lookup.
dbChats := make([]database.Chat, len(chatRows))
// Collect root chat IDs so we can fetch their children.
rootIDs := make([]uuid.UUID, len(chatRows))
for i, row := range chatRows {
dbChats[i] = row.Chat
rootIDs[i] = row.Chat.ID
}
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, dbChats)
// Embed children matching the caller's archive filter so
// sidebar views don't surface state-mismatched rows.
var childRows []database.GetChildChatsByParentIDsRow
if len(rootIDs) > 0 {
childRows, err = api.Database.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
ParentIds: rootIDs,
Archived: searchParams.Archived,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list child chats.",
Detail: err.Error(),
})
return
}
}
// Collect all chat objects (root + child) for diff status lookup.
allChats := make([]database.Chat, 0, len(chatRows)+len(childRows))
for _, row := range chatRows {
allChats = append(allChats, row.Chat)
}
for _, row := range childRows {
allChats = append(allChats, row.Chat)
}
diffStatusesByChatID, err := api.getChatDiffStatusesByChatID(ctx, allChats)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list chats.",
@@ -351,7 +377,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
return
}
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatRows(chatRows, diffStatusesByChatID))
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatRowsWithChildren(chatRows, childRows, diffStatusesByChatID))
}
func (api *API) getChatDiffStatusesByChatID(
@@ -1506,7 +1532,41 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
// Hydrate file metadata for all files linked to this chat.
chatFiles := api.fetchChatFileMetadata(ctx, chat.ID)
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus, chatFiles))
sdkChat := db2sdk.Chat(chat, diffStatus, chatFiles)
// For root chats, embed children so callers get a complete
// tree in a single response.
if !chat.ParentChatID.Valid {
// Embed children matching the parent's archive state.
childRows, err := api.Database.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
ParentIds: []uuid.UUID{chat.ID},
Archived: sql.NullBool{Bool: chat.Archived, Valid: true},
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to fetch child chats.",
Detail: err.Error(),
})
return
}
// Look up diff statuses for children.
childChats := make([]database.Chat, len(childRows))
for i, row := range childRows {
childChats[i] = row.Chat
}
childDiffStatuses, err := api.getChatDiffStatusesByChatID(ctx, childChats)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to fetch child chat diff statuses.",
Detail: err.Error(),
})
return
}
sdkChat.Children = db2sdk.ChildChatRows(childRows, childDiffStatuses)
}
httpapi.Write(ctx, rw, http.StatusOK, sdkChat)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
@@ -1907,6 +1967,16 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
return
}
// Archive invariant is one-way: parent archived implies
// child archived. Parent archive/unarchive cascade via
// root_chat_id; individual child archive is permitted;
// child unarchive while the parent is archived is rejected
// (enforced atomically in chatd.Server.UnarchiveChat).
if chat.ParentChatID.Valid && !archived {
if done := api.writeChildUnarchiveGuard(ctx, rw, chat); done {
return
}
}
var err error
// Use chatDaemon when available so it can interrupt active
// processing before broadcasting archive state. Fall back to
@@ -1925,6 +1995,12 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
}
}
if err != nil {
if errors.Is(err, chatd.ErrChildUnarchiveParentArchived) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Cannot unarchive a child chat while its parent is archived. Unarchive the parent chat to cascade.",
})
return
}
action := "archive"
if !archived {
action = "unarchive"
@@ -2046,6 +2122,38 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// writeChildUnarchiveGuard returns a 400 early when a child unarchive
// request obviously races an archived parent. The durable invariant
// is enforced atomically in chatd.Server.UnarchiveChat; this guard
// just surfaces the error before we take any locks.
//
// Returns true when a response has been written.
func (api *API) writeChildUnarchiveGuard(
ctx context.Context,
rw http.ResponseWriter,
chat database.Chat,
) bool {
parent, err := api.Database.GetChatByID(ctx, chat.ParentChatID.UUID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return true
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to load parent chat.",
Detail: err.Error(),
})
return true
}
if parent.Archived {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Cannot unarchive a child chat while its parent is archived. Unarchive the parent chat to cascade.",
})
return true
}
return false
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+421 -18
View File
@@ -1251,6 +1251,158 @@ func TestListChats(t *testing.T) {
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
"pin_order=2 chat should be second")
})
t.Run("ChildChatsEmbeddedNotStandalone", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Create a parent chat via the API.
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "root chat with children",
},
},
})
require.NoError(t, err)
// Insert child chats directly via the database.
child1, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child one",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
child2, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child two",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
// Also create a standalone root chat to verify it still appears.
standalone, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "standalone root chat",
},
},
})
require.NoError(t, err)
chats, err := client.ListChats(ctx, nil)
require.NoError(t, err)
// Only root chats should appear at the top level.
rootIDs := make(map[uuid.UUID]struct{}, len(chats))
for _, c := range chats {
rootIDs[c.ID] = struct{}{}
require.Nil(t, c.ParentChatID, "top-level entry should have no parent")
}
require.Contains(t, rootIDs, parentChat.ID)
require.Contains(t, rootIDs, standalone.ID)
require.NotContains(t, rootIDs, child1.ID, "child1 should not appear at top level")
require.NotContains(t, rootIDs, child2.ID, "child2 should not appear at top level")
// Find the parent in the list and verify children are embedded.
var parent codersdk.Chat
for _, c := range chats {
if c.ID == parentChat.ID {
parent = c
break
}
}
require.Len(t, parent.Children, 2, "parent should embed 2 children")
// Children should be ordered by created_at ASC.
childIDs := []uuid.UUID{parent.Children[0].ID, parent.Children[1].ID}
require.Equal(t, child1.ID, childIDs[0])
require.Equal(t, child2.ID, childIDs[1])
// Verify each child has correct parent/root references.
for _, child := range parent.Children {
require.NotNil(t, child.ParentChatID)
require.Equal(t, parentChat.ID, *child.ParentChatID)
require.NotNil(t, child.RootChatID)
require.Equal(t, parentChat.ID, *child.RootChatID)
}
// Standalone root chat should have an empty children slice.
for _, c := range chats {
if c.ID == standalone.ID {
require.NotNil(t, c.Children)
require.Empty(t, c.Children)
break
}
}
})
t.Run("PaginationCountsOnlyRootChats", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Create 3 root chats, each with 2 children.
for i := range 3 {
parent, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: fmt.Sprintf("parent %d", i),
},
},
})
require.NoError(t, err)
for j := range 2 {
_, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: fmt.Sprintf("child %d-%d", i, j),
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
})
require.NoError(t, err)
}
}
// Request with limit=2: should get 2 root chats (not 2 of
// the 9 total chats). Each root should have its children.
chats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: 2},
})
require.NoError(t, err)
require.Len(t, chats, 2, "limit should apply to root chats only")
for _, c := range chats {
require.Nil(t, c.ParentChatID)
require.Len(t, c.Children, 2, "each root should embed its 2 children")
}
})
}
func TestListChatModels(t *testing.T) {
@@ -3692,6 +3844,65 @@ func TestGetChat(t *testing.T) {
require.NoError(t, err)
require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs)
})
t.Run("GetChatEmbedsChildren", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "parent for getChat",
},
},
})
require.NoError(t, err)
child, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child for getChat",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
// Fetching the root chat should embed its children.
result, err := client.GetChat(ctx, parentChat.ID)
require.NoError(t, err)
require.Len(t, result.Children, 1)
require.Equal(t, child.ID, result.Children[0].ID)
require.NotNil(t, result.Children[0].ParentChatID)
require.Equal(t, parentChat.ID, *result.Children[0].ParentChatID)
// Fetching a child chat should not have children.
childResult, err := client.GetChat(ctx, child.ID)
require.NoError(t, err)
require.NotNil(t, childResult.Children)
require.Empty(t, childResult.Children)
// An archived root should still embed its cascaded
// archived children (guards against the filter getting
// hardcoded to false).
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
archivedResult, err := client.GetChat(ctx, parentChat.ID)
require.NoError(t, err)
require.True(t, archivedResult.Archived, "root should be archived")
require.Len(t, archivedResult.Children, 1, "archived root should embed its archived child")
require.Equal(t, child.ID, archivedResult.Children[0].ID)
require.True(t, archivedResult.Children[0].Archived, "embedded child should be archived")
})
}
func TestPatchChat(t *testing.T) {
@@ -4083,6 +4294,100 @@ func TestArchiveChat(t *testing.T) {
dbChild2, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child2.ID)
require.NoError(t, err)
require.True(t, dbChild2.Archived, "child2 should be archived")
// archived:true should return the parent with both
// cascaded children embedded.
archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Query: "archived:true",
})
require.NoError(t, err)
var foundParent *codersdk.Chat
for _, chat := range archivedChats {
if chat.ID == parentChat.ID {
foundParent = &chat
break
}
}
require.NotNil(t, foundParent, "parent should appear in archived list")
require.True(t, foundParent.Archived, "parent should be archived")
require.Len(t, foundParent.Children, 2, "both archived children should be embedded under the archived parent")
childIDs := map[uuid.UUID]bool{}
for _, child := range foundParent.Children {
require.True(t, child.Archived, "embedded child should be archived")
childIDs[child.ID] = true
}
require.True(t, childIDs[child1.ID], "child1 should be embedded under archived parent")
require.True(t, childIDs[child2.ID], "child2 should be embedded under archived parent")
})
t.Run("AllowsChildChatArchiveIndividually", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Create a parent chat via the API.
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "parent",
},
},
})
require.NoError(t, err)
// Insert a child chat directly via the database.
child, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
// Individual child archive is permitted and leaves the
// parent active; the invariant is one-way.
err = client.UpdateChat(ctx, child.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
dbChild, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child.ID)
require.NoError(t, err)
require.True(t, dbChild.Archived, "child should be archived")
dbParent, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), parentChat.ID)
require.NoError(t, err)
require.False(t, dbParent.Archived, "parent should stay active")
// Archived child is hidden under an active parent.
activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "archived:false"})
require.NoError(t, err)
var activeParent *codersdk.Chat
for i := range activeChats {
if activeChats[i].ID == parentChat.ID {
activeParent = &activeChats[i]
break
}
}
require.NotNil(t, activeParent, "parent should appear in active list")
for _, c := range activeParent.Children {
require.NotEqual(t, child.ID, c.ID, "archived child must not appear under active parent")
}
// Nor does the child surface in the archived list (only
// roots paginate there).
archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{Query: "archived:true"})
require.NoError(t, err)
for _, c := range archivedChats {
require.NotEqual(t, child.ID, c.ID, "archived child should not surface as a root in archived list")
}
})
}
@@ -4192,25 +4497,28 @@ func TestUnarchiveChat(t *testing.T) {
})
require.NoError(t, err)
var foundParent bool
var foundChild1 bool
var foundChild2 bool
// Children no longer appear as top-level entries.
// They are embedded inside the parent's Children field.
var foundParent *codersdk.Chat
for _, chat := range activeChats {
switch chat.ID {
case parentChat.ID:
foundParent = true
require.False(t, chat.Archived)
case child1.ID:
foundChild1 = true
require.False(t, chat.Archived)
case child2.ID:
foundChild2 = true
require.False(t, chat.Archived)
require.NotEqual(t, child1.ID, chat.ID, "child1 should not appear at top level")
require.NotEqual(t, child2.ID, chat.ID, "child2 should not appear at top level")
if chat.ID == parentChat.ID {
foundParent = &chat
}
}
require.True(t, foundParent, "parent should be listed as active")
require.True(t, foundChild1, "child1 should be listed as active")
require.True(t, foundChild2, "child2 should be listed as active")
require.NotNil(t, foundParent, "parent should be listed as active")
require.False(t, foundParent.Archived)
// Verify children are embedded and unarchived.
require.Len(t, foundParent.Children, 2)
childIDs := map[uuid.UUID]bool{}
for _, child := range foundParent.Children {
require.False(t, child.Archived)
childIDs[child.ID] = true
}
require.True(t, childIDs[child1.ID], "child1 should be embedded")
require.True(t, childIDs[child2.ID], "child2 should be embedded")
archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Query: "archived:true",
@@ -4218,8 +4526,6 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
for _, chat := range archivedChats {
require.NotEqual(t, parentChat.ID, chat.ID, "parent should not remain archived")
require.NotEqual(t, child1.ID, chat.ID, "child1 should not remain archived")
require.NotEqual(t, child2.ID, chat.ID, "child2 should not remain archived")
}
dbParent, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), parentChat.ID)
@@ -4258,6 +4564,103 @@ func TestUnarchiveChat(t *testing.T) {
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("RejectsChildChatWhenParentArchived", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Create a parent chat via the API.
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "parent",
},
},
})
require.NoError(t, err)
// Insert a child directly via the database, then archive the
// parent so the whole family is archived (cascade).
child, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
// Unarchiving the child while the parent stays archived
// must be rejected. Otherwise the child becomes a ghost
// (active list excludes the parent, archived list's child
// query filters archived=true so the now-unarchived child
// is also excluded).
err = client.UpdateChat(ctx, child.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
requireSDKError(t, err, http.StatusBadRequest)
dbChild, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child.ID)
require.NoError(t, err)
require.True(t, dbChild.Archived, "child should still be archived")
})
t.Run("AllowsChildChatWhenParentNotArchived", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "parent",
},
},
})
require.NoError(t, err)
// Simulate legacy lone-archived child (from before the
// child-archive gate existed) by inserting it directly
// with archived=true while the parent is not archived.
child, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OrganizationID: user.OrganizationID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "legacy child",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
_, err = db.ArchiveChatByID(dbauthz.AsSystemRestricted(ctx), child.ID)
require.NoError(t, err)
// Unarchiving the child is permitted because the parent is
// already active; this is the recovery path for legacy
// data.
err = client.UpdateChat(ctx, child.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
require.NoError(t, err)
dbChild, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child.ID)
require.NoError(t, err)
require.False(t, dbChild.Archived, "child should be unarchived")
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
-2
View File
@@ -2046,7 +2046,6 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
WorkspaceID: workspace.ID,
OwnerID: workspace.OwnerID,
Branch: gitRef.Branch,
Origin: gitRef.RemoteOrigin,
ChatID: gitRef.ChatID,
@@ -2201,7 +2200,6 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
WorkspaceID: workspace.ID,
OwnerID: workspace.OwnerID,
Branch: gitRef.Branch,
Origin: gitRef.RemoteOrigin,
ChatID: gitRef.ChatID,
+58 -9
View File
@@ -1462,20 +1462,69 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
return nil
}
// UnarchiveChat unarchives a chat family and publishes created events for
// each affected chat so watching clients see every chat that reappeared.
// ErrChildUnarchiveParentArchived is returned by UnarchiveChat when a
// child unarchive is rejected because the parent is still archived.
// The patchChat handler maps this to a 400 response.
var ErrChildUnarchiveParentArchived = xerrors.New(
"cannot unarchive child chat while parent is archived",
)
// UnarchiveChat unarchives a chat family and broadcasts created events.
// Root chats cascade through UnarchiveChatByID. Child chats run under
// a row-level lock on the child (GetChatByIDForUpdate) with an
// in-transaction re-read of the parent, returning
// ErrChildUnarchiveParentArchived when the parent is archived and a
// no-op when the child is already active.
//
// The child is locked before the parent is read to avoid deadlocking
// with a concurrent ArchiveChatByID cascade, which visits child rows
// before the parent.
func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
if chat.ID == uuid.Nil {
return xerrors.New("chat_id is required")
}
return p.applyChatLifecycleTransition(
ctx,
chat.ID,
"unarchive",
codersdk.ChatWatchEventKindCreated,
p.db.UnarchiveChatByID,
)
if !chat.ParentChatID.Valid {
return p.applyChatLifecycleTransition(
ctx,
chat.ID,
"unarchive",
codersdk.ChatWatchEventKindCreated,
p.db.UnarchiveChatByID,
)
}
var updated []database.Chat
if err := p.db.InTx(func(tx database.Store) error {
locked, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("lock child for unarchive: %w", err)
}
if !locked.Archived {
// Already unarchived by a concurrent caller; idempotent no-op.
return nil
}
parent, err := tx.GetChatByID(ctx, chat.ParentChatID.UUID)
if err != nil {
return xerrors.Errorf("load parent chat: %w", err)
}
if parent.Archived {
return ErrChildUnarchiveParentArchived
}
updated, err = tx.UnarchiveChatByID(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("unarchive child chat: %w", err)
}
return nil
}, nil); err != nil {
if errors.Is(err, ErrChildUnarchiveParentArchived) {
return ErrChildUnarchiveParentArchived
}
return err
}
p.publishChatPubsubEvents(updated, codersdk.ChatWatchEventKindCreated)
return nil
}
func (p *Server) applyChatLifecycleTransition(
+137 -10
View File
@@ -644,11 +644,19 @@ func TestExploreSubagentIsReadOnly(t *testing.T) {
require.True(t, requestHasSystemSubstring(childRequests[0], "You are in Explore Mode as a delegated sub-agent."))
require.False(t, requestHasSystemSubstring(rootRequests[0], "You are in Explore Mode as a delegated sub-agent."))
allChats, err := db.GetChats(dbauthz.AsChatd(ctx), database.GetChatsParams{OwnerID: user.UserID})
rootChats, err := db.GetChats(dbauthz.AsChatd(ctx), database.GetChatsParams{OwnerID: user.UserID})
require.NoError(t, err)
rootIDs := make([]uuid.UUID, 0, len(rootChats))
for _, root := range rootChats {
rootIDs = append(rootIDs, root.Chat.ID)
}
childRows, err := db.GetChildChatsByParentIDs(dbauthz.AsChatd(ctx), database.GetChildChatsByParentIDsParams{
ParentIds: rootIDs,
})
require.NoError(t, err)
var exploreChildren []database.Chat
for _, candidate := range allChats {
if candidate.Chat.ParentChatID.Valid && candidate.Chat.Mode.Valid && candidate.Chat.Mode.ChatMode == database.ChatModeExplore {
for _, candidate := range childRows {
if candidate.Chat.Mode.Valid && candidate.Chat.Mode.ChatMode == database.ChatModeExplore {
exploreChildren = append(exploreChildren, candidate.Chat)
}
}
@@ -733,6 +741,127 @@ func TestArchiveChatMovesPendingChatToWaiting(t *testing.T) {
require.Zero(t, fromDB.PinOrder)
}
// TestUnarchiveChildChat covers the deterministic branches of the
// Server.UnarchiveChat child path: happy path, archived-parent reject,
// and already-active no-op.
func TestUnarchiveChildChat(t *testing.T) {
t.Parallel()
t.Run("ChildWithActiveParentUnarchives", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, org, model := seedChatDependencies(ctx, t, db)
parent, child := insertParentWithArchivedChild(ctx, t, db, user, org, model)
require.NoError(t, replica.UnarchiveChat(ctx, child))
dbChild, err := db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
require.False(t, dbChild.Archived, "child should be unarchived")
dbParent, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
require.False(t, dbParent.Archived, "parent should stay active")
})
t.Run("ChildWithArchivedParentRejected", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, org, model := seedChatDependencies(ctx, t, db)
parent, child := insertParentWithArchivedChild(ctx, t, db, user, org, model)
_, err := db.ArchiveChatByID(ctx, parent.ID)
require.NoError(t, err)
err = replica.UnarchiveChat(ctx, child)
require.ErrorIs(t, err, chatd.ErrChildUnarchiveParentArchived)
dbChild, err := db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
require.True(t, dbChild.Archived, "child should remain archived")
})
t.Run("AlreadyActiveChildNoOp", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, org, model := seedChatDependencies(ctx, t, db)
_, child := insertParentWithActiveChild(ctx, t, db, user, org, model)
require.NoError(t, replica.UnarchiveChat(ctx, child))
dbChild, err := db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
require.False(t, dbChild.Archived, "child should stay active")
})
}
// insertParentWithActiveChild creates a parent chat and an active
// child chat linked to it. Both are returned in their initial
// (active) state.
func insertParentWithActiveChild(
ctx context.Context,
t *testing.T,
db database.Store,
user database.User,
org database.Organization,
model database.ChatModelConfig,
) (parent database.Chat, child database.Chat) {
t.Helper()
var err error
parent, err = db.InsertChat(ctx, database.InsertChatParams{
OrganizationID: org.ID,
OwnerID: user.ID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
LastModelConfigID: model.ID,
Title: "parent",
})
require.NoError(t, err)
child, err = db.InsertChat(ctx, database.InsertChatParams{
OrganizationID: org.ID,
OwnerID: user.ID,
Status: database.ChatStatusWaiting,
ClientType: database.ChatClientTypeUi,
LastModelConfigID: model.ID,
Title: "child",
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
})
require.NoError(t, err)
return parent, child
}
// insertParentWithArchivedChild creates an active parent and an
// individually-archived child. The returned child reflects its
// current (archived) state in the DB.
func insertParentWithArchivedChild(
ctx context.Context,
t *testing.T,
db database.Store,
user database.User,
org database.Organization,
model database.ChatModelConfig,
) (parent database.Chat, child database.Chat) {
t.Helper()
parent, child = insertParentWithActiveChild(ctx, t, db, user, org, model)
_, err := db.ArchiveChatByID(ctx, child.ID)
require.NoError(t, err)
child, err = db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
return parent, child
}
func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
t.Parallel()
@@ -4976,15 +5105,13 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
// 6. Verify the child chat has Mode = computer_use in
// the DB.
allChats, err := db.GetChats(ctx, database.GetChatsParams{
OwnerID: user.ID,
childRows, err := db.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
ParentIds: []uuid.UUID{chat.ID},
})
require.NoError(t, err)
var children []database.Chat
for _, c := range allChats {
if c.Chat.ParentChatID.Valid && c.Chat.ParentChatID.UUID == chat.ID {
children = append(children, c.Chat)
}
children := make([]database.Chat, 0, len(childRows))
for _, row := range childRows {
children = append(children, row.Chat)
}
require.Len(t, children, 1)
require.True(t, children[0].Mode.Valid)
+9 -29
View File
@@ -66,9 +66,9 @@ type Store interface {
UpsertChatDiffStatusReference(
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
) (database.ChatDiffStatus, error)
GetChats(
ctx context.Context, arg database.GetChatsParams,
) ([]database.GetChatsRow, error)
GetChatsByWorkspaceIDs(
ctx context.Context, ids []uuid.UUID,
) ([]database.Chat, error)
}
// EventPublisher notifies the frontend of diff status changes.
@@ -277,7 +277,6 @@ func (w *Worker) tick(ctx context.Context) {
// MarkStaleParams holds the arguments for Worker.MarkStale.
type MarkStaleParams struct {
WorkspaceID uuid.UUID
OwnerID uuid.UUID
Branch string
Origin string
// ChatID, when set, targets a single chat instead of
@@ -306,9 +305,11 @@ func (w *Worker) MarkStale(ctx context.Context, p MarkStaleParams) {
return
}
chatRows, err := w.store.GetChats(ctx, database.GetChatsParams{
OwnerID: p.OwnerID,
})
// Broadcast path: scope by workspace. GetChatsByWorkspaceIDs
// filters archived=false, which is intentional: archived
// chats aren't in the active sidebar and don't need refreshed
// git refs.
chats, err := w.store.GetChatsByWorkspaceIDs(ctx, []uuid.UUID{p.WorkspaceID})
if err != nil {
w.logger.Warn(ctx, "list chats for git ref storage",
slog.F("workspace_id", p.WorkspaceID),
@@ -316,12 +317,7 @@ func (w *Worker) MarkStale(ctx context.Context, p MarkStaleParams) {
return
}
chats := make([]database.Chat, len(chatRows))
for i, row := range chatRows {
chats[i] = row.Chat
}
for _, chat := range filterChatsByWorkspaceID(chats, p.WorkspaceID) {
for _, chat := range chats {
w.markStaleSingle(ctx, chat.ID, p.Branch, p.Origin)
}
}
@@ -403,19 +399,3 @@ func (w *Worker) RefreshChat(
return &upserted, nil
}
// filterChatsByWorkspaceID returns only chats associated with
// the given workspace.
func filterChatsByWorkspaceID(
chats []database.Chat,
workspaceID uuid.UUID,
) []database.Chat {
filtered := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
continue
}
filtered = append(filtered, chat)
}
return filtered
}
+23 -36
View File
@@ -606,7 +606,6 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
ownerID := uuid.New()
chat1 := uuid.New()
chat2 := uuid.New()
chatOther := uuid.New()
var mu sync.Mutex
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
@@ -615,13 +614,12 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.GetChatsRow{
{Chat: database.Chat{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
{Chat: database.Chat{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
{Chat: database.Chat{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}}},
store.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, ids []uuid.UUID) ([]database.Chat, error) {
require.Equal(t, []uuid.UUID{workspaceID}, ids)
return []database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
@@ -646,7 +644,6 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "feature",
Origin: "https://github.com/owner/repo",
})
@@ -672,16 +669,12 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
Return([]database.GetChatsRow{
{Chat: database.Chat{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}}},
{Chat: database.Chat{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}}},
}, nil)
store.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), gomock.Any()).
Return(nil, nil)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
@@ -690,7 +683,6 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "main",
Origin: "https://github.com/x/y",
})
@@ -710,10 +702,10 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
Return([]database.GetChatsRow{
{Chat: database.Chat{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
{Chat: database.Chat{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
store.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
}, nil)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
@@ -735,7 +727,6 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "dev",
Origin: "https://github.com/a/b",
})
@@ -743,14 +734,14 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
func TestWorker_MarkStale_GetChatsByWorkspaceIDsFails(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
store.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), gomock.Any()).
Return(nil, fmt.Errorf("db error"))
mClock := quartz.NewMock(t)
@@ -760,7 +751,6 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: uuid.New(),
OwnerID: uuid.New(),
Branch: "main",
Origin: "https://github.com/x/y",
})
@@ -817,7 +807,6 @@ func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: uuid.New(),
OwnerID: uuid.New(),
Branch: tc.branch,
Origin: tc.origin,
})
@@ -838,8 +827,8 @@ func TestWorker_MarkStale_WithChatID(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
// GetChats should NOT be called when a specific chat ID is provided.
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).Times(0)
// GetChatsByWorkspaceIDs should NOT be called when a specific chat ID is provided.
store.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), gomock.Any()).Times(0)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
@@ -862,7 +851,6 @@ func TestWorker_MarkStale_WithChatID(t *testing.T) {
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: uuid.New(),
OwnerID: uuid.New(),
Branch: "my-branch",
Origin: "https://github.com/org/repo",
ChatID: targetChat,
@@ -897,13 +885,13 @@ func TestWorker_MarkStale_NilChatID_Broadcasts(t *testing.T) {
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
// GetChats IS called because a nil ChatID triggers the
// workspace-wide broadcast path.
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.GetChatsRow{
{Chat: database.Chat{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
// Broadcast path: GetChatsByWorkspaceIDs scopes the query to
// the workspace directly; no post-filtering needed.
store.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, ids []uuid.UUID) ([]database.Chat, error) {
require.Equal(t, []uuid.UUID{workspaceID}, ids)
return []database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
@@ -928,7 +916,6 @@ func TestWorker_MarkStale_NilChatID_Broadcasts(t *testing.T) {
// Zero-value ChatID (uuid.Nil) triggers broadcast.
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "main",
Origin: "https://github.com/org/repo",
})
+6
View File
@@ -95,6 +95,12 @@ type Chat struct {
LastInjectedContext []ChatMessagePart `json:"last_injected_context,omitempty"`
Warnings []string `json:"warnings,omitempty"`
ClientType ChatClientType `json:"client_type"`
// Children holds child (subagent) chats nested under this root
// chat. Always initialized to an empty slice so the JSON field
// is present as []. Child chats cannot create their own
// subagents, so nesting depth is capped at 1 and this slice is
// always empty for child chats.
Children []Chat `json:"children"`
}
// ChatFileMetadata contains lightweight metadata about a file
+179
View File
@@ -4,6 +4,7 @@ import { API } from "#/api/api";
import type * as TypesGen from "#/api/typesGenerated";
import { buildOptimisticEditedMessage } from "./chatMessageEdits";
import {
appendChildToParentInCache,
archiveChat,
cancelChatListRefetches,
chatCostSummary,
@@ -23,10 +24,12 @@ import {
pinChat,
promoteChatQueuedMessage,
regenerateChatTitle,
removeChildFromParentInCache,
reorderPinnedChat,
unarchiveChat,
unpinChat,
updateChatPlanMode,
updateChildInParentCache,
updateInfiniteChatsCache,
} from "./chats";
@@ -95,6 +98,7 @@ const makeChat = (
has_unread: false,
client_type: "ui",
last_error: null,
children: [],
...overrides,
});
@@ -263,6 +267,29 @@ describe("archiveChat optimistic update", () => {
expect(cachedChat?.archived).toBe(true);
});
it("strips an individually-archived child from its parent's embedded children", async () => {
const queryClient = createTestQueryClient();
const child = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const sibling = makeChat("child-2", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const parent = makeChat("parent-1", { children: [child, sibling] });
seedInfiniteChats(queryClient, [parent]);
vi.mocked(API.experimental.updateChat).mockResolvedValue();
const mutation = archiveChat(queryClient);
await mutation.onMutate("child-1");
const result = readInfiniteChats(queryClient);
expect(result?.[0].children).toHaveLength(1);
expect(result?.[0].children?.[0].id).toBe("child-2");
});
it("rolls back the chats list on error by invalidating", async () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
@@ -1675,3 +1702,155 @@ describe("mutation onMutate cancels pagination fetches", () => {
expect(chat?.archived).toBe(true);
});
});
describe("appendChildToParentInCache", () => {
it("appends the child to the matching parent's children array", () => {
const queryClient = createTestQueryClient();
const parent = makeChat("parent-1");
seedInfiniteChats(queryClient, [parent]);
const child = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
appendChildToParentInCache(queryClient, child, "parent-1");
const result = readInfiniteChats(queryClient);
expect(result).toHaveLength(1);
expect(result?.[0].children).toHaveLength(1);
expect(result?.[0].children?.[0].id).toBe("child-1");
});
it("silently drops the child when the parent is not in any page", () => {
const queryClient = createTestQueryClient();
const other = makeChat("other-root");
seedInfiniteChats(queryClient, [other]);
const child = makeChat("orphan-child", {
parent_chat_id: "missing-parent",
root_chat_id: "missing-parent",
});
appendChildToParentInCache(queryClient, child, "missing-parent");
const result = readInfiniteChats(queryClient);
expect(result).toHaveLength(1);
expect(result?.[0].id).toBe("other-root");
expect(result?.[0].children).toHaveLength(0);
});
it("does not duplicate a child that already exists under the parent", () => {
const queryClient = createTestQueryClient();
const existingChild = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const parent = makeChat("parent-1", { children: [existingChild] });
seedInfiniteChats(queryClient, [parent]);
appendChildToParentInCache(queryClient, existingChild, "parent-1");
const result = readInfiniteChats(queryClient);
expect(result?.[0].children).toHaveLength(1);
});
});
describe("updateChildInParentCache", () => {
it("applies the updater to a child nested under its parent", () => {
const queryClient = createTestQueryClient();
const child = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
title: "Original title",
});
const parent = makeChat("parent-1", { children: [child] });
seedInfiniteChats(queryClient, [parent]);
const found = updateChildInParentCache(
queryClient,
(c) => ({ ...c, title: "Updated title" }),
"child-1",
);
expect(found).toBe(true);
const result = readInfiniteChats(queryClient);
expect(result?.[0].children?.[0].title).toBe("Updated title");
});
it("returns false when the child is not present under any parent", () => {
const queryClient = createTestQueryClient();
const parent = makeChat("parent-1");
seedInfiniteChats(queryClient, [parent]);
const found = updateChildInParentCache(
queryClient,
(c) => ({ ...c, title: "Never applied" }),
"missing-child",
);
expect(found).toBe(false);
});
it("preserves the same reference when the updater returns the child unchanged", () => {
const queryClient = createTestQueryClient();
const child = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const parent = makeChat("parent-1", { children: [child] });
seedInfiniteChats(queryClient, [parent]);
const before = readInfiniteChats(queryClient)?.[0];
const found = updateChildInParentCache(queryClient, (c) => c, "child-1");
const after = readInfiniteChats(queryClient)?.[0];
expect(found).toBe(false);
expect(after).toBe(before);
});
});
describe("removeChildFromParentInCache", () => {
it("removes the child from its parent's children array", () => {
const queryClient = createTestQueryClient();
const child = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const sibling = makeChat("child-2", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const parent = makeChat("parent-1", { children: [child, sibling] });
seedInfiniteChats(queryClient, [parent]);
const found = removeChildFromParentInCache(queryClient, "child-1");
expect(found).toBe(true);
const result = readInfiniteChats(queryClient);
expect(result?.[0].children).toHaveLength(1);
expect(result?.[0].children?.[0].id).toBe("child-2");
});
it("returns false when no parent embeds the given child", () => {
const queryClient = createTestQueryClient();
const parent = makeChat("parent-1");
seedInfiniteChats(queryClient, [parent]);
const found = removeChildFromParentInCache(queryClient, "missing-child");
expect(found).toBe(false);
});
it("preserves the parent reference when the child is not found", () => {
const queryClient = createTestQueryClient();
const child = makeChat("child-1", {
parent_chat_id: "parent-1",
root_chat_id: "parent-1",
});
const parent = makeChat("parent-1", { children: [child] });
seedInfiniteChats(queryClient, [parent]);
const before = readInfiniteChats(queryClient)?.[0];
removeChildFromParentInCache(queryClient, "missing-child");
const after = readInfiniteChats(queryClient)?.[0];
expect(after).toBe(before);
});
});
+88
View File
@@ -102,6 +102,90 @@ export const readInfiniteChatsCache = (
return undefined;
};
/**
* Appends a child chat to its parent's `children` array across all
* infinite chat query caches. If the parent is not in any loaded page,
* the child is silently dropped (it will appear when the parent loads).
*/
export const appendChildToParentInCache = (
queryClient: QueryClient,
child: TypesGen.Chat,
parentId: string,
) => {
updateInfiniteChatsCache(queryClient, (chats) => {
let changed = false;
const next = chats.map((c) => {
if (c.id !== parentId) return c;
// Avoid duplicates.
if (c.children?.some((ch) => ch.id === child.id)) return c;
changed = true;
return { ...c, children: [...(c.children ?? []), child] };
});
return changed ? next : chats;
});
};
/**
* Updates a child chat within its parent's `children` array across all
* infinite chat query caches. Returns true if the child was found and
* updated, false otherwise.
*/
export const updateChildInParentCache = (
queryClient: QueryClient,
updater: (child: TypesGen.Chat) => TypesGen.Chat,
childId: string,
) => {
let found = false;
updateInfiniteChatsCache(queryClient, (chats) => {
let changed = false;
const next = chats.map((c) => {
if (!c.children?.length) return c;
let childChanged = false;
const nextChildren = c.children.map((ch) => {
if (ch.id !== childId) return ch;
const updated = updater(ch);
if (updated !== ch) {
childChanged = true;
found = true;
}
return updated;
});
if (!childChanged) return c;
changed = true;
return { ...c, children: nextChildren };
});
return changed ? next : chats;
});
return found;
};
/**
* Removes a child chat from its parent's `children` array across all
* infinite chat query caches. Returns true if the child was found and
* removed, false otherwise. Used when a child is archived individually
* (the sidebar hides children whose archive state differs from the
* parent) and when a `deleted` pubsub event arrives for a child chat.
*/
export const removeChildFromParentInCache = (
queryClient: QueryClient,
childId: string,
) => {
let found = false;
updateInfiniteChatsCache(queryClient, (chats) => {
let changed = false;
const next = chats.map((c) => {
if (!c.children?.length) return c;
const filtered = c.children.filter((ch) => ch.id !== childId);
if (filtered.length === c.children.length) return c;
found = true;
changed = true;
return { ...c, children: filtered };
});
return changed ? next : chats;
});
return found;
};
const getNextOptimisticPinOrder = (queryClient: QueryClient): number => {
let maxPinOrder = 0;
const queries = queryClient.getQueriesData<
@@ -309,11 +393,15 @@ export const archiveChat = (queryClient: QueryClient) => ({
const previousChat = queryClient.getQueryData<TypesGen.Chat>(
chatKey(chatId),
);
// Flip archived flag in the flat root list; strip the
// chat from any parent's embedded children (individual
// child archive).
updateInfiniteChatsCache(queryClient, (chats) =>
chats.map((chat) =>
chat.id === chatId ? { ...chat, archived: true } : chat,
),
);
removeChildFromParentInCache(queryClient, chatId);
if (previousChat) {
queryClient.setQueryData<TypesGen.Chat>(chatKey(chatId), {
...previousChat,
+8
View File
@@ -1261,6 +1261,14 @@ export interface Chat {
readonly last_injected_context?: readonly ChatMessagePart[];
readonly warnings?: readonly string[];
readonly client_type: ChatClientType;
/**
* Children holds child (subagent) chats nested under this root
* chat. Always initialized to an empty slice so the JSON field
* is present as []. Child chats cannot create their own
* subagents, so nesting depth is capped at 1 and this slice is
* always empty for child chats.
*/
readonly children: readonly Chat[];
}
// From codersdk/chats.go
@@ -134,6 +134,7 @@ const baseChatFields = {
has_unread: false,
client_type: "ui",
last_error: null,
children: [],
} as const;
// ---------------------------------------------------------------------------
@@ -58,6 +58,7 @@ const buildChat = (overrides: Partial<TypesGen.Chat> = {}): TypesGen.Chat => ({
has_unread: false,
client_type: "ui",
last_error: null,
children: [],
...overrides,
});
+73 -46
View File
@@ -10,6 +10,7 @@ import { toast } from "sonner";
import { API, watchChats } from "#/api/api";
import { getErrorMessage } from "#/api/errors";
import {
appendChildToParentInCache,
archiveChat,
cancelChatListRefetches,
chatDiffContentsKey,
@@ -23,9 +24,11 @@ import {
prependToInfiniteChatsCache,
readInfiniteChatsCache,
regenerateChatTitle,
removeChildFromParentInCache,
reorderPinnedChat,
unarchiveChat,
unpinChat,
updateChildInParentCache,
updateInfiniteChatsCache,
} from "#/api/queries/chats";
import { workspaceById } from "#/api/queries/workspaces";
@@ -502,19 +505,22 @@ const AgentsPage: FC = () => {
}
if (chatEvent.kind === "deleted") {
// Drop the chat from the flat root list (root or
// cascade via root_chat_id) and from any parent's
// embedded children (individual child archive).
updateInfiniteChatsCache(queryClient, (chats) =>
chats.filter(
(c) =>
c.id !== updatedChat.id && c.root_chat_id !== updatedChat.id,
),
);
removeChildFromParentInCache(queryClient, updatedChat.id);
queryClient.removeQueries({
queryKey: chatKey(updatedChat.id),
exact: true,
});
return;
}
if (chatEvent.kind === "diff_status_change") {
// Only refetch the diff file contents — the chat's
// diff_status field is already written into the
@@ -560,59 +566,80 @@ const AgentsPage: FC = () => {
// page, so a naive prepend would duplicate the
// chat into every loaded page.
if (chatEvent.kind === "created") {
prependToInfiniteChatsCache(queryClient, updatedChat);
if (updatedChat.parent_chat_id) {
// Child chat: append to its parent's children
// array. If the parent is not in any loaded
// page, the child is silently dropped.
appendChildToParentInCache(
queryClient,
updatedChat,
updatedChat.parent_chat_id,
);
} else {
prependToInfiniteChatsCache(queryClient, updatedChat);
}
} else {
// Build a field updater shared between root and
// child cache update paths.
const applyFields = (c: TypesGen.Chat): TypesGen.Chat => {
const nextStatus = isStatusEvent ? updatedChat.status : c.status;
const nextTitle = isTitleEvent ? updatedChat.title : c.title;
const nextDiffStatus = isDiffStatusEvent
? updatedChat.diff_status
: c.diff_status;
const nextWorkspaceId =
updatedChat.workspace_id ?? c.workspace_id;
const nextBuildId = updatedChat.build_id ?? c.build_id;
const nextUpdatedAt =
c.updated_at > updatedChat.updated_at
? c.updated_at
: updatedChat.updated_at;
// The server's pubsub path does not compute
// has_unread (it always sends false). For
// status_change events on non-active chats,
// optimistically mark as unread since the
// assistant produced new output.
const nextHasUnread =
isStatusEvent && updatedChat.id !== activeChatIDRef.current
? true
: c.has_unread;
if (
nextStatus === c.status &&
nextTitle === c.title &&
diffStatusEqual(nextDiffStatus, c.diff_status) &&
nextWorkspaceId === c.workspace_id &&
nextBuildId === c.build_id &&
nextHasUnread === c.has_unread
) {
return c;
}
return {
...c,
status: nextStatus,
title: nextTitle,
diff_status: nextDiffStatus,
workspace_id: nextWorkspaceId,
build_id: nextBuildId,
updated_at: nextUpdatedAt,
has_unread: nextHasUnread,
};
};
// Try root-level update first.
updateInfiniteChatsCache(queryClient, (chats) => {
let didUpdate = false;
const nextChats = chats.map((c) => {
if (c.id !== updatedChat.id) return c;
const nextStatus = isStatusEvent
? updatedChat.status
: c.status;
const nextTitle = isTitleEvent ? updatedChat.title : c.title;
const nextDiffStatus = isDiffStatusEvent
? updatedChat.diff_status
: c.diff_status;
const nextWorkspaceId =
updatedChat.workspace_id ?? c.workspace_id;
const nextBuildId = updatedChat.build_id ?? c.build_id;
const nextUpdatedAt =
c.updated_at > updatedChat.updated_at
? c.updated_at
: updatedChat.updated_at;
// The server's pubsub path does not compute
// has_unread (it always sends false). For
// status_change events on non-active chats,
// optimistically mark as unread since the
// assistant produced new output.
const nextHasUnread =
isStatusEvent && updatedChat.id !== activeChatIDRef.current
? true
: c.has_unread;
if (
nextStatus === c.status &&
nextTitle === c.title &&
diffStatusEqual(nextDiffStatus, c.diff_status) &&
nextWorkspaceId === c.workspace_id &&
nextBuildId === c.build_id &&
nextHasUnread === c.has_unread
) {
return c;
}
didUpdate = true;
return {
...c,
status: nextStatus,
title: nextTitle,
diff_status: nextDiffStatus,
workspace_id: nextWorkspaceId,
build_id: nextBuildId,
updated_at: nextUpdatedAt,
has_unread: nextHasUnread,
};
const result = applyFields(c);
if (result !== c) didUpdate = true;
return result;
});
return didUpdate ? nextChats : chats;
});
// Also update inside parent's children array
// in case the event targets a child chat.
updateChildInParentCache(queryClient, applyFields, updatedChat.id);
}
queryClient.setQueryData<TypesGen.Chat | undefined>(
chatKey(updatedChat.id),
@@ -143,6 +143,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
has_unread: false,
client_type: "ui",
last_error: null,
children: [],
...overrides,
});
@@ -219,6 +219,7 @@ const makeChat = (chatID: string): TypesGen.Chat => ({
has_unread: false,
client_type: "ui",
last_error: null,
children: [],
});
const makeMessage = (
@@ -63,6 +63,7 @@ export const WithParentChat: Story = {
pin_order: 0,
has_unread: false,
client_type: "ui",
children: [],
},
},
};
@@ -54,6 +54,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
has_unread: false,
client_type: "ui",
last_error: null,
children: [],
...overrides,
});
@@ -110,13 +111,18 @@ type Story = StoryObj<typeof AgentsSidebar>;
export const RunningDelegatedChat: Story = {
args: {
chats: [
buildChat({ id: "root-1", title: "Root agent" }),
buildChat({
id: "child-running",
title: "Running child",
status: "running",
parent_chat_id: "root-1",
root_chat_id: "root-1",
id: "root-1",
title: "Root agent",
children: [
buildChat({
id: "child-running",
title: "Running child",
status: "running",
parent_chat_id: "root-1",
root_chat_id: "root-1",
}),
],
}),
],
},
@@ -140,13 +146,18 @@ export const RunningDelegatedChat: Story = {
export const PendingDelegatedChat: Story = {
args: {
chats: [
buildChat({ id: "root-pending", title: "Root agent" }),
buildChat({
id: "child-pending",
title: "Pending child",
status: "pending",
parent_chat_id: "root-pending",
root_chat_id: "root-pending",
id: "root-pending",
title: "Root agent",
children: [
buildChat({
id: "child-pending",
title: "Pending child",
status: "pending",
parent_chat_id: "root-pending",
root_chat_id: "root-pending",
}),
],
}),
],
},
@@ -170,12 +181,17 @@ export const PendingDelegatedChat: Story = {
export const ExpandCollapse: Story = {
args: {
chats: [
buildChat({ id: "root-2", title: "Root for collapse" }),
buildChat({
id: "child-collapse",
title: "Nested child",
parent_chat_id: "root-2",
root_chat_id: "root-2",
id: "root-2",
title: "Root for collapse",
children: [
buildChat({
id: "child-collapse",
title: "Nested child",
parent_chat_id: "root-2",
root_chat_id: "root-2",
}),
],
}),
],
},
@@ -212,12 +228,14 @@ export const RunningChatPreservesSpinner: Story = {
id: "root-running",
title: "Running root agent",
status: "running",
}),
buildChat({
id: "child-of-running",
title: "Child of running",
parent_chat_id: "root-running",
root_chat_id: "root-running",
children: [
buildChat({
id: "child-of-running",
title: "Child of running",
parent_chat_id: "root-running",
root_chat_id: "root-running",
}),
],
}),
],
},
@@ -257,13 +275,15 @@ export const IdleParentWithRunningChild: Story = {
id: "idle-parent",
title: "Idle parent agent",
status: "waiting",
}),
buildChat({
id: "running-child",
title: "Running sub-agent",
status: "running",
parent_chat_id: "idle-parent",
root_chat_id: "idle-parent",
children: [
buildChat({
id: "running-child",
title: "Running sub-agent",
status: "running",
parent_chat_id: "idle-parent",
root_chat_id: "idle-parent",
}),
],
}),
],
},
@@ -289,6 +309,10 @@ export const IdleParentWithRunningChild: Story = {
};
export const ActiveChatAncestryExpanded: Story = {
// This story uses the flat-sibling shape (parent_chat_id on each
// entry) rather than embedded children. It intentionally
// exercises the defensive fallback loop in buildChatTree for
// stale cache data from before root-only pagination landed.
args: {
chats: [
buildChat({ id: "root-active", title: "Active root" }),
@@ -332,6 +356,57 @@ export const ActiveChatAncestryExpanded: Story = {
},
};
export const MixedCacheDoesNotDuplicateChild: Story = {
// Simulates the rollout window where a stale cache entry for a child
// chat still appears as a flat sibling in the paginated list while
// the same child is also embedded under its parent. Without the
// guard in buildChatTree (`if (!parentById.has(chat.id))` around
// setting the parent link to undefined), the flat entry would
// overwrite the embedded parent link and the defensive fallback
// would re-add the child to its parent's children list, producing
// a React duplicate-key warning and double-render.
args: {
chats: [
buildChat({
id: "mixed-root",
title: "Mixed root",
children: [
buildChat({
id: "mixed-child",
title: "Mixed child",
parent_chat_id: "mixed-root",
root_chat_id: "mixed-root",
}),
],
}),
// Stale flat entry for the same child still present in the
// cache. It must not cause a duplicate render.
buildChat({
id: "mixed-child",
title: "Mixed child",
parent_chat_id: "mixed-root",
root_chat_id: "mixed-root",
}),
],
},
parameters: {
reactRouter: reactRouterParameters({
location: {
path: "/agents/mixed-child",
pathParams: { agentId: "mixed-child" },
},
routing: agentsRouting,
}),
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
await expect(canvas.getByText("Mixed root")).toBeInTheDocument();
await waitFor(() => {
expect(canvas.getAllByText("Mixed child")).toHaveLength(1);
});
},
};
// Use a fixed offset so the value always falls in the "Today" bucket
// without embedding a literal date that drifts across calendar days.
const recentTimestamp = new Date(Date.now() - 60_000).toISOString();
@@ -67,6 +67,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
last_error: null,
mcp_server_ids: [],
labels: {},
children: [],
...overrides,
});
@@ -157,6 +157,7 @@ const statusConfig = {
type ChatTree = {
readonly rootIds: readonly string[];
readonly chatById: ReadonlyMap<string, Chat>;
readonly childrenById: ReadonlyMap<string, readonly string[]>;
readonly parentById: ReadonlyMap<string, string | undefined>;
};
@@ -263,45 +264,54 @@ const getParentChatID = (chat: Chat): string | undefined => {
return asNonEmptyString(chat.parent_chat_id);
};
const getRootChatID = (chat: Chat): string | undefined => {
return asNonEmptyString(chat.root_chat_id);
};
const buildChatTree = (chats: readonly Chat[]): ChatTree => {
const orderById = new Map<string, number>();
const chatById = new Map<string, Chat>();
const parentById = new Map<string, string | undefined>();
const childrenById = new Map<string, string[]>();
for (const [index, chat] of chats.entries()) {
orderById.set(chat.id, index);
// The paginated list now contains only root chats. Children
// are embedded in each root's `children` field.
for (const chat of chats) {
chatById.set(chat.id, chat);
childrenById.set(chat.id, []);
}
for (const chat of chats) {
let parentID = getParentChatID(chat);
if (!parentID || parentID === chat.id || !chatById.has(parentID)) {
parentID = undefined;
// Guard against stale cache entries: if a flat child
// entry appears in `chats` after its embedded parent has
// already set its parent link, do not overwrite the link
// with `undefined`. Without this, the defensive fallback
// below re-adds the child to its parent's list, producing
// a duplicate key in React rendering.
if (!parentById.has(chat.id)) {
parentById.set(chat.id, undefined);
}
if (!parentID) {
const rootID = getRootChatID(chat);
if (rootID && rootID !== chat.id && chatById.has(rootID)) {
parentID = rootID;
if (chat.children) {
for (const child of chat.children) {
chatById.set(child.id, child);
parentById.set(child.id, chat.id);
childrenById.get(chat.id)?.push(child.id);
// Children cannot have their own children (depth
// capped at 1), but initialize the map entry for
// uniform lookup.
childrenById.set(child.id, []);
}
}
parentById.set(chat.id, parentID);
if (parentID) {
childrenById.get(parentID)?.push(chat.id);
}
}
for (const children of childrenById.values()) {
children.sort((leftID, rightID) => {
return (orderById.get(leftID) ?? 0) - (orderById.get(rightID) ?? 0);
});
// Defensive fallback for cached data during rollout: if any
// chat has a parent_chat_id that points to a chat in the list
// but was not embedded, build the link. This handles stale
// cache entries from before the backend change.
for (const chat of chats) {
const parentID = getParentChatID(chat);
if (
parentID &&
parentID !== chat.id &&
chatById.has(parentID) &&
!parentById.get(chat.id)
) {
parentById.set(chat.id, parentID);
childrenById.get(parentID)?.push(chat.id);
}
}
const rootIds = chats
@@ -310,6 +320,7 @@ const buildChatTree = (chats: readonly Chat[]): ChatTree => {
return {
rootIds,
chatById,
childrenById,
parentById,
};
@@ -325,10 +336,17 @@ const collectVisibleChatIDs = ({
readonly tree: ChatTree;
}): Set<string> => {
if (!search) {
return new Set(chats.map((chat) => chat.id));
const allIDs = new Set(chats.map((chat) => chat.id));
for (const chat of chats) {
for (const child of chat.children ?? []) {
allIDs.add(child.id);
}
}
return allIDs;
}
const matchedChatIDs = chats
const allChats = chats.flatMap((chat) => [chat, ...(chat.children ?? [])]);
const matchedChatIDs = allChats
.filter((chat) => chat.title.toLowerCase().includes(search))
.map((chat) => chat.id);
if (matchedChatIDs.length === 0) {
@@ -780,7 +798,7 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
const [expandedById, setExpandedById] = useState<Record<string, boolean>>({});
const chatTree = buildChatTree(chats);
const chatById = new Map(chats.map((chat) => [chat.id, chat] as const));
const chatById = chatTree.chatById;
const visibleChatIDs = collectVisibleChatIDs({
chats,
search: normalizedSearch,