mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): unarchive child chats with parents (#23761)
Unarchiving a root chat now restores descendant chats in the database and emits lifecycle events for every affected chat so passive sessions converge without a full refetch. This keeps archive and unarchive symmetric at both the data and watch-stream layers by returning the affected chat family from the database, using those post-update rows for chatd pubsub fanout, and covering descendant lifecycle delivery with a watch-level regression test. Closes #23666
This commit is contained in:
@@ -1570,13 +1570,13 @@ func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UU
|
||||
return q.db.AllUserIDs(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ArchiveChatByID(ctx, id)
|
||||
}
|
||||
@@ -5649,13 +5649,13 @@ func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (q *querier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return q.db.UnarchiveChatByID(ctx, id)
|
||||
}
|
||||
|
||||
@@ -392,14 +392,14 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("ArchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().ArchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("UnarchiveChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
|
||||
@@ -160,12 +160,12 @@ func (m queryMetricsStore) AllUserIDs(ctx context.Context, includeSystem bool) (
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.ArchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.ArchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("ArchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ArchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, arg database.ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) {
|
||||
@@ -4024,12 +4024,12 @@ func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXact
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m queryMetricsStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.UnarchiveChatByID(ctx, id)
|
||||
r0, r1 := m.s.UnarchiveChatByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UnarchiveChatByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UnarchiveChatByID").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UnarchiveTemplateVersion(ctx context.Context, arg database.UnarchiveTemplateVersionParams) error {
|
||||
|
||||
@@ -148,11 +148,12 @@ func (mr *MockStoreMockRecorder) AllUserIDs(ctx, includeSystem any) *gomock.Call
|
||||
}
|
||||
|
||||
// ArchiveChatByID mocks base method.
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ArchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ArchiveChatByID indicates an expected call of ArchiveChatByID.
|
||||
@@ -7632,11 +7633,12 @@ func (mr *MockStoreMockRecorder) TryAcquireLock(ctx, pgTryAdvisoryXactLock any)
|
||||
}
|
||||
|
||||
// UnarchiveChatByID mocks base method.
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
func (m *MockStore) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnarchiveChatByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UnarchiveChatByID indicates an expected call of UnarchiveChatByID.
|
||||
|
||||
@@ -54,7 +54,7 @@ type sqlcQuerier interface {
|
||||
ActivityBumpWorkspace(ctx context.Context, arg ActivityBumpWorkspaceParams) error
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error)
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// Archiving templates is a soft delete action, so is reversible.
|
||||
// Archiving prevents the version from being used and discovered
|
||||
// by listing.
|
||||
@@ -844,7 +844,7 @@ type sqlcQuerier interface {
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error
|
||||
|
||||
@@ -10646,7 +10646,8 @@ func TestChatPinOrderQueries(t *testing.T) {
|
||||
}
|
||||
|
||||
// Archive the middle pin.
|
||||
require.NoError(t, db.ArchiveChatByID(ctx, second.ID))
|
||||
_, err := db.ArchiveChatByID(ctx, second.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archived chat should have pin_order cleared. Remaining
|
||||
// pins keep their original positions; the next mutation
|
||||
|
||||
+110
-11
@@ -4241,14 +4241,63 @@ func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const archiveChatByID = `-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = $1 OR root_chat_id = $1
|
||||
const archiveChatByID = `-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = $1::uuid OR root_chat_id = $1::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
)
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
FROM chats
|
||||
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, archiveChatByID, id)
|
||||
return err
|
||||
func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, archiveChatByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec
|
||||
@@ -6168,13 +6217,63 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
|
||||
return err
|
||||
}
|
||||
|
||||
const unarchiveChatByID = `-- name: UnarchiveChatByID :exec
|
||||
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = $1::uuid
|
||||
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = $1::uuid OR root_chat_id = $1::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
)
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context
|
||||
FROM chats
|
||||
ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, unarchiveChatByID, id)
|
||||
return err
|
||||
func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, unarchiveChatByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const unpinChatByID = `-- name: UnpinChatByID :exec
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id OR root_chat_id = @id;
|
||||
-- name: ArchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :exec
|
||||
UPDATE chats SET archived = false, updated_at = NOW() WHERE id = @id::uuid;
|
||||
-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
SELECT *
|
||||
FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: PinChatByID :exec
|
||||
WITH target_chat AS (
|
||||
|
||||
+2
-2
@@ -1648,13 +1648,13 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
_, err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
} else {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
_, err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -1052,6 +1052,102 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "watch root chat",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
childOne, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "watch child 1",
|
||||
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
childTwo, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "watch child 2",
|
||||
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
}
|
||||
events[payload.Chat.ID] = payload
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
for _, chatID := range []uuid.UUID{parentChat.ID, childOne.ID, childTwo.ID} {
|
||||
payload, ok := events[chatID]
|
||||
require.True(t, ok, "missing event for chat %s", chatID)
|
||||
require.Equal(t, archived, payload.Chat.Archived)
|
||||
}
|
||||
}
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
t.Run("Unauthenticated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2210,6 +2306,96 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.Empty(t, archivedChats)
|
||||
})
|
||||
|
||||
t.Run("UnarchivesChildren", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "parent chat",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
child1, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "child 1",
|
||||
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
child2, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "child 2",
|
||||
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
|
||||
activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Query: "archived:false",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundParent bool
|
||||
var foundChild1 bool
|
||||
var foundChild2 bool
|
||||
for _, chat := range activeChats {
|
||||
switch chat.ID {
|
||||
case parentChat.ID:
|
||||
foundParent = true
|
||||
require.False(t, chat.Archived)
|
||||
case child1.ID:
|
||||
foundChild1 = true
|
||||
require.False(t, chat.Archived)
|
||||
case child2.ID:
|
||||
foundChild2 = true
|
||||
require.False(t, chat.Archived)
|
||||
}
|
||||
}
|
||||
require.True(t, foundParent, "parent should be listed as active")
|
||||
require.True(t, foundChild1, "child1 should be listed as active")
|
||||
require.True(t, foundChild2, "child2 should be listed as active")
|
||||
|
||||
archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Query: "archived:true",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
for _, chat := range archivedChats {
|
||||
require.NotEqual(t, parentChat.ID, chat.ID, "parent should not remain archived")
|
||||
require.NotEqual(t, child1.ID, chat.ID, "child1 should not remain archived")
|
||||
require.NotEqual(t, child2.ID, chat.ID, "child2 should not remain archived")
|
||||
}
|
||||
|
||||
dbParent, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, dbParent.Archived, "parent should be unarchived")
|
||||
|
||||
dbChild1, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child1.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, dbChild1.Archived, "child1 should be unarchived")
|
||||
|
||||
dbChild2, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child2.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, dbChild2.Archived, "child2 should be unarchived")
|
||||
})
|
||||
|
||||
t.Run("NotArchived", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+37
-12
@@ -1244,9 +1244,10 @@ func (p *Server) EditMessage(
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ArchiveChat archives a chat and all descendants. If the target chat is
|
||||
// pending or running, it first transitions the chat back to waiting so active
|
||||
// processing stops before the archive is broadcast.
|
||||
// ArchiveChat archives a chat family and broadcasts deleted events for each
|
||||
// affected chat so watching clients converge without a full refetch. If the
|
||||
// target chat is pending or running, it first transitions the chat back to
|
||||
// waiting so active processing stops before the archive is broadcast.
|
||||
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
if chat.ID == uuid.Nil {
|
||||
return xerrors.New("chat_id is required")
|
||||
@@ -1254,6 +1255,7 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
|
||||
statusChat := chat
|
||||
interrupted := false
|
||||
var archivedChats []database.Chat
|
||||
if err := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
@@ -1279,7 +1281,8 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
interrupted = true
|
||||
}
|
||||
|
||||
if err := tx.ArchiveChatByID(ctx, chat.ID); err != nil {
|
||||
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -1292,24 +1295,39 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
statusChat.Archived = true
|
||||
statusChat.PinOrder = 0
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindDeleted, nil)
|
||||
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnarchiveChat unarchives a chat and publishes a created event so sidebar
|
||||
// clients are notified that the chat has reappeared.
|
||||
// UnarchiveChat unarchives a chat family and publishes created events for
|
||||
// each affected chat so watching clients see every chat that reappeared.
|
||||
func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
if chat.ID == uuid.Nil {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
if err := p.db.UnarchiveChatByID(ctx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("unarchive chat: %w", err)
|
||||
return p.applyChatLifecycleTransition(
|
||||
ctx,
|
||||
chat.ID,
|
||||
"unarchive",
|
||||
coderdpubsub.ChatEventKindCreated,
|
||||
p.db.UnarchiveChatByID,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Server) applyChatLifecycleTransition(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
action string,
|
||||
kind coderdpubsub.ChatEventKind,
|
||||
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
|
||||
) error {
|
||||
updatedChats, err := transition(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("%s chat: %w", action, err)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.publishChatPubsubEvents(updatedChats, kind)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3139,6 +3157,13 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
|
||||
}
|
||||
}
|
||||
|
||||
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
|
||||
for _, chat := range chats {
|
||||
p.publishChatPubsubEvent(chat, kind, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
||||
// pubsub so that all replicas can push updates to watching clients.
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
|
||||
Reference in New Issue
Block a user