mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): unarchive child chats with parents (#23761)
Unarchiving a root chat now restores descendant chats in the database and emits lifecycle events for every affected chat so passive sessions converge without a full refetch. This keeps archive and unarchive symmetric at both the data and watch-stream layers by returning the affected chat family from the database, using those post-update rows for chatd pubsub fanout, and covering descendant lifecycle delivery with a watch-level regression test. Closes #23666
This commit is contained in:
@@ -1570,13 +1570,13 @@ func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UU
|
||||
return q.db.AllUserIDs(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ArchiveChatByID(ctx, id)
|
||||
}
|
||||
@@ -5649,13 +5649,13 @@ func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.UnarchiveChatByID(ctx, id)
|
||||
}
|
||||
|
||||
@@ -392,14 +392,14 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("ArchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("UnarchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
|
||||
@@ -160,12 +160,12 @@ func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) (
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.ArchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.ArchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("ArchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ArchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) {
|
||||
@@ -4024,12 +4024,12 @@ func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXact
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnarchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.UnarchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UnarchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnarchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveTemplateVersion(ctx context.Context, arg database.UnarchiveTemplateVersionParams) error {
|
||||
|
||||
@@ -148,11 +148,12 @@ func (mr *MockStoreMockRecorder) AllUserIDs(ctx, includeSystem any) *gomock.Call
|
||||
}
|
||||
|
||||
// ArchiveChatByID mocks base method.
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ArchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ArchiveChatByID indicates an expected call of ArchiveChatByID.
|
||||
@@ -7632,11 +7633,12 @@ func (mr *MockStoreMockRecorder) TryAcquireLock(ctx, pgTryAdvisoryXactLock any)
|
||||
}
|
||||
|
||||
// UnarchiveChatByID mocks base method.
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnarchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UnarchiveChatByID indicates an expected call of UnarchiveChatByID.
|
||||
|
||||
@@ -54,7 +54,7 @@ type sqlcQuerier interface {
|
||||
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// Archiving templates is a soft delete action, so is reversible.
|
||||
// Archiving prevents the version from being used and discovered
|
||||
// by listing.
|
||||
@@ -844,7 +844,7 @@ type sqlcQuerier interface {
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
|
||||
@@ -10646,7 +10646,8 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
}
|
||||
|
||||
// Archive the middle pin.
|
||||
require.NoError(t, db.ArchiveChatByID(ctx, second.ID))
|
||||
_, err := db.ArchiveChatByID(ctx, second.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archived chat should have pin_order cleared. Remaining
|
||||
// pins keep their original positions; the next mutation
|
||||
|
||||
+110
-11
@@ -4241,14 +4241,63 @@ func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const archiveChatByID = `-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = $1 OR root_chat_id = $1
|
||||
const archiveChatByID = `-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = $1::uuid OR root_chat_id = $1::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
)
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
FROM chats
|
||||
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, archiveChatByID, id)
|
||||
return err
|
||||
func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, archiveChatByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec
|
||||
@@ -6168,13 +6217,63 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
|
||||
return err
|
||||
}
|
||||
|
||||
const unarchiveChatByID = `-- name: UnarchiveChatByID :exec
|
||||
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = $1::uuid
|
||||
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = $1::uuid OR root_chat_id = $1::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
)
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
FROM chats
|
||||
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, unarchiveChatByID, id)
|
||||
return err
|
||||
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, unarchiveChatByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const unpinChatByID = `-- name: UnpinChatByID :exec
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id OR root_chat_id = @id;
|
||||
-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :exec
|
||||
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = @id::uuid;
|
||||
-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: PinChatByID :exec
|
||||
WITH target_chat AS (
|
||||
|
||||
Reference in New Issue
Block a user