fix(coderd): unarchive child chats with parents (#23761)

Unarchiving a root chat now restores descendant chats in the database
and emits lifecycle events for every affected chat so passive sessions
converge without a full refetch.

This keeps archive and unarchive symmetric at both the data and
watch-stream layers by returning the affected chat family from the
database, using those post-update rows for chatd pubsub fanout, and
covering descendant lifecycle delivery with a watch-level regression
test.

Closes #23666
This commit is contained in:
Ethan
2026-04-01 15:30:25 +11:00
committed by GitHub
parent 1d16ff1ca6
commit 5cba59af79
11 changed files with 383 additions and 55 deletions
+6 -6
View File
@@ -1570,13 +1570,13 @@ func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UU
return q.db.AllUserIDs(ctx, includeSystem)
}
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return err
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
return nil, err
}
return q.db.ArchiveChatByID(ctx, id)
}
@@ -5649,13 +5649,13 @@ func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
return err
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
return nil, err
}
return q.db.UnarchiveChatByID(ctx, id)
}
+4 -4
View File
@@ -392,14 +392,14 @@ func (s *MethodTestSuite) TestChats() {
s.Run("ArchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
}))
s.Run("UnarchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
}))
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
+6 -6
View File
@@ -160,12 +160,12 @@ func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) (
return r0, r1
}
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0 := m.s.ArchiveChatByID(ctx, id)
r0, r1 := m.s.ArchiveChatByID(ctx, id)
m.queryLatencies.WithLabelValues("ArchiveChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ArchiveChatByID").Inc()
return r0
return r0, r1
}
func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) {
@@ -4024,12 +4024,12 @@ func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXact
return r0, r1
}
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0 := m.s.UnarchiveChatByID(ctx, id)
r0, r1 := m.s.UnarchiveChatByID(ctx, id)
m.queryLatencies.WithLabelValues("UnarchiveChatByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnarchiveChatByID").Inc()
return r0
return r0, r1
}
func (m queryMetricsStore) UnarchiveTemplateVersion(ctx context.Context, arg database.UnarchiveTemplateVersionParams) error {
+8 -6
View File
@@ -148,11 +148,12 @@ func (mr *MockStoreMockRecorder) AllUserIDs(ctx, includeSystem any) *gomock.Call
}
// ArchiveChatByID mocks base method.
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ArchiveChatByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ArchiveChatByID indicates an expected call of ArchiveChatByID.
@@ -7632,11 +7633,12 @@ func (mr *MockStoreMockRecorder) TryAcquireLock(ctx, pgTryAdvisoryXactLock any)
}
// UnarchiveChatByID mocks base method.
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnarchiveChatByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UnarchiveChatByID indicates an expected call of UnarchiveChatByID.
+2 -2
View File
@@ -54,7 +54,7 @@ type sqlcQuerier interface {
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
// AllUserIDs returns all UserIDs regardless of user status or deletion.
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
ArchiveChatByID(ctx context.Context, id uuid.UUID) error
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
// Archiving templates is a soft delete action, so is reversible.
// Archiving prevents the version from being used and discovered
// by listing.
@@ -844,7 +844,7 @@ type sqlcQuerier interface {
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
// This will always work regardless of the current state of the template version.
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
+2 -1
View File
@@ -10646,7 +10646,8 @@ func TestChatPinOrderQueries(t *testing.T) {
}
// Archive the middle pin.
require.NoError(t, db.ArchiveChatByID(ctx, second.ID))
_, err := db.ArchiveChatByID(ctx, second.ID)
require.NoError(t, err)
// Archived chat should have pin_order cleared. Remaining
// pins keep their original positions; the next mutation
+110 -11
View File
@@ -4241,14 +4241,63 @@ func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal
return items, nil
}
const archiveChatByID = `-- name: ArchiveChatByID :exec
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
WHERE id = $1 OR root_chat_id = $1
const archiveChatByID = `-- name: ArchiveChatByID :many
WITH chats AS (
UPDATE chats
SET archived = true, pin_order = 0, updated_at = NOW()
WHERE id = $1::uuid OR root_chat_id = $1::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
)
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
FROM chats
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
`
func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, archiveChatByID, id)
return err
func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
rows, err := q.db.QueryContext(ctx, archiveChatByID, id)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Chat
for rows.Next() {
var i Chat
if err := rows.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec
@@ -6168,13 +6217,63 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
return err
}
const unarchiveChatByID = `-- name: UnarchiveChatByID :exec
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = $1::uuid
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
WITH chats AS (
UPDATE chats
SET archived = false, updated_at = NOW()
WHERE id = $1::uuid OR root_chat_id = $1::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
)
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
FROM chats
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
`
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, unarchiveChatByID, id)
return err
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
rows, err := q.db.QueryContext(ctx, unarchiveChatByID, id)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Chat
for rows.Next() {
var i Chat
if err := rows.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const unpinChatByID = `-- name: UnpinChatByID :exec
+20 -5
View File
@@ -1,9 +1,24 @@
-- name: ArchiveChatByID :exec
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
WHERE id = @id OR root_chat_id = @id;
-- name: ArchiveChatByID :many
WITH chats AS (
UPDATE chats
SET archived = true, pin_order = 0, updated_at = NOW()
WHERE id = @id::uuid OR root_chat_id = @id::uuid
RETURNING *
)
SELECT *
FROM chats
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
-- name: UnarchiveChatByID :exec
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = @id::uuid;
-- name: UnarchiveChatByID :many
WITH chats AS (
UPDATE chats
SET archived = false, updated_at = NOW()
WHERE id = @id::uuid OR root_chat_id = @id::uuid
RETURNING *
)
SELECT *
FROM chats
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
-- name: PinChatByID :exec
WITH target_chat AS (
+2 -2
View File
@@ -1648,13 +1648,13 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat)
} else {
err = api.Database.ArchiveChatByID(ctx, chat.ID)
_, err = api.Database.ArchiveChatByID(ctx, chat.ID)
}
} else {
if api.chatDaemon != nil {
err = api.chatDaemon.UnarchiveChat(ctx, chat)
} else {
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
_, err = api.Database.UnarchiveChatByID(ctx, chat.ID)
}
}
if err != nil {
+186
View File
@@ -1052,6 +1052,102 @@ func TestWatchChats(t *testing.T) {
}
})
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "watch root chat",
},
},
})
require.NoError(t, err)
childOne, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "watch child 1",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
childTwo, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "watch child 2",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil)
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
var ping watchEvent
err = wsjson.Read(ctx, conn, &ping)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
t.Helper()
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
for len(events) < 3 {
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var payload coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &payload)
require.NoError(t, err)
if payload.Kind != expectedKind {
continue
}
events[payload.Chat.ID] = payload
}
return events
}
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
t.Helper()
require.Len(t, events, 3)
for _, chatID := range []uuid.UUID{parentChat.ID, childOne.ID, childTwo.ID} {
payload, ok := events[chatID]
require.True(t, ok, "missing event for chat %s", chatID)
require.Equal(t, archived, payload.Chat.Archived)
}
}
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
assertLifecycleEvents(deletedEvents, true)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
require.NoError(t, err)
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
assertLifecycleEvents(createdEvents, false)
})
t.Run("Unauthenticated", func(t *testing.T) {
t.Parallel()
@@ -2210,6 +2306,96 @@ func TestUnarchiveChat(t *testing.T) {
require.Empty(t, archivedChats)
})
t.Run("UnarchivesChildren", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "parent chat",
},
},
})
require.NoError(t, err)
child1, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child 1",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
child2, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "child 2",
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
})
require.NoError(t, err)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
require.NoError(t, err)
activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Query: "archived:false",
})
require.NoError(t, err)
var foundParent bool
var foundChild1 bool
var foundChild2 bool
for _, chat := range activeChats {
switch chat.ID {
case parentChat.ID:
foundParent = true
require.False(t, chat.Archived)
case child1.ID:
foundChild1 = true
require.False(t, chat.Archived)
case child2.ID:
foundChild2 = true
require.False(t, chat.Archived)
}
}
require.True(t, foundParent, "parent should be listed as active")
require.True(t, foundChild1, "child1 should be listed as active")
require.True(t, foundChild2, "child2 should be listed as active")
archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Query: "archived:true",
})
require.NoError(t, err)
for _, chat := range archivedChats {
require.NotEqual(t, parentChat.ID, chat.ID, "parent should not remain archived")
require.NotEqual(t, child1.ID, chat.ID, "child1 should not remain archived")
require.NotEqual(t, child2.ID, chat.ID, "child2 should not remain archived")
}
dbParent, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), parentChat.ID)
require.NoError(t, err)
require.False(t, dbParent.Archived, "parent should be unarchived")
dbChild1, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child1.ID)
require.NoError(t, err)
require.False(t, dbChild1.Archived, "child1 should be unarchived")
dbChild2, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child2.ID)
require.NoError(t, err)
require.False(t, dbChild2.Archived, "child2 should be unarchived")
})
t.Run("NotArchived", func(t *testing.T) {
t.Parallel()
+37 -12
View File
@@ -1244,9 +1244,10 @@ func (p *Server) EditMessage(
return result, nil
}
// ArchiveChat archives a chat and all descendants. If the target chat is
// pending or running, it first transitions the chat back to waiting so active
// processing stops before the archive is broadcast.
// ArchiveChat archives a chat family and broadcasts deleted events for each
// affected chat so watching clients converge without a full refetch. If the
// target chat is pending or running, it first transitions the chat back to
// waiting so active processing stops before the archive is broadcast.
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
if chat.ID == uuid.Nil {
return xerrors.New("chat_id is required")
@@ -1254,6 +1255,7 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
statusChat := chat
interrupted := false
var archivedChats []database.Chat
if err := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
@@ -1279,7 +1281,8 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
interrupted = true
}
if err := tx.ArchiveChatByID(ctx, chat.ID); err != nil {
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("archive chat: %w", err)
}
return nil
@@ -1292,24 +1295,39 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
}
statusChat.Archived = true
statusChat.PinOrder = 0
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindDeleted, nil)
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
return nil
}
// UnarchiveChat unarchives a chat and publishes a created event so sidebar
// clients are notified that the chat has reappeared.
// UnarchiveChat unarchives a chat family and publishes created events for
// each affected chat so watching clients see every chat that reappeared.
func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
if chat.ID == uuid.Nil {
return xerrors.New("chat_id is required")
}
if err := p.db.UnarchiveChatByID(ctx, chat.ID); err != nil {
return xerrors.Errorf("unarchive chat: %w", err)
return p.applyChatLifecycleTransition(
ctx,
chat.ID,
"unarchive",
coderdpubsub.ChatEventKindCreated,
p.db.UnarchiveChatByID,
)
}
func (p *Server) applyChatLifecycleTransition(
ctx context.Context,
chatID uuid.UUID,
action string,
kind coderdpubsub.ChatEventKind,
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
) error {
updatedChats, err := transition(ctx, chatID)
if err != nil {
return xerrors.Errorf("%s chat: %w", action, err)
}
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
p.publishChatPubsubEvents(updatedChats, kind)
return nil
}
@@ -3139,6 +3157,13 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
}
}
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
for _, chat := range chats {
p.publishChatPubsubEvent(chat, kind, nil)
}
}
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
// pubsub so that all replicas can push updates to watching clients.
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {