fix(coderd/x/chatd): archive chat hard-interrupts active stream (#23758)

Archiving a chat now transitions pending or running chats to waiting
before setting the archived flag. This publishes a status notification
on `ChatStreamNotifyChannel` so `subscribeChatControl` cancels the
active `processChat` context via `ErrInterrupted` — the same codepath
used by the stop button.

The `processChat` cleanup also skips queued-message auto-promotion when
the chat is archived, so archiving behaves like a hard stop rather than
interrupt-and-continue.

Relates to https://github.com/coder/coder/issues/23666
This commit is contained in:
Ethan
2026-04-01 00:23:52 +11:00
committed by GitHub
parent 9fa103929a
commit bbf3fbc830
3 changed files with 224 additions and 9 deletions
+3 -3
View File
@@ -1641,9 +1641,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
}
var err error
// Use chatDaemon when available so it can notify active
// subscribers. Fall back to direct DB for the simple
// archive flag — no streaming state is involved.
// Use chatDaemon when available so it can interrupt active
// processing before broadcasting archive state. Fall back to
// direct DB when no daemon is running.
if archived {
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat)
+47 -6
View File
@@ -1244,17 +1244,57 @@ func (p *Server) EditMessage(
return result, nil
}
// ArchiveChat archives a chat and all descendants, then broadcasts a deleted event.
// ArchiveChat archives a chat and all descendants. If the target chat is
// pending or running, it first transitions the chat back to waiting so active
// processing stops before the archive is broadcast.
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
if chat.ID == uuid.Nil {
return xerrors.New("chat_id is required")
}
if err := p.db.ArchiveChatByID(ctx, chat.ID); err != nil {
return xerrors.Errorf("archive chat: %w", err)
statusChat := chat
interrupted := false
if err := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("lock chat for archive: %w", err)
}
statusChat = lockedChat
// We do not call setChatWaiting here because it intentionally preserves
// pending chats so queued-message promotion can win. Archiving is a
// harder stop: both pending and running chats must transition to waiting.
if lockedChat.Status == database.ChatStatusPending || lockedChat.Status == database.ChatStatusRunning {
statusChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
if err != nil {
return xerrors.Errorf("set chat waiting before archive: %w", err)
}
interrupted = true
}
if err := tx.ArchiveChatByID(ctx, chat.ID); err != nil {
return xerrors.Errorf("archive chat: %w", err)
}
return nil
}, nil); err != nil {
return err
}
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDeleted, nil)
if interrupted {
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
}
statusChat.Archived = true
statusChat.PinOrder = 0
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindDeleted, nil)
return nil
}
@@ -3563,9 +3603,10 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
// the worker and let the processor pick it back up.
if latestChat.Status == database.ChatStatusPending {
status = database.ChatStatusPending
} else if status == database.ChatStatusWaiting {
} else if status == database.ChatStatusWaiting && !latestChat.Archived {
// Queued messages were already admitted through SendMessage,
// so auto-promotion only preserves FIFO order here.
// so auto-promotion only preserves FIFO order here. Archived
// chats skip promotion so archiving behaves like a hard stop.
var promoteErr error
promotedMessage, remainingQueuedMessages, shouldPublishQueueUpdate, promoteErr = p.tryAutoPromoteQueuedMessage(cleanupCtx, tx, latestChat)
if promoteErr != nil {
+174
View File
@@ -297,6 +297,180 @@ func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
require.False(t, fromDB.WorkerID.Valid)
}
// TestArchiveChatMovesPendingChatToWaiting checks that calling ArchiveChat on
// a chat whose status is pending leaves the stored row in the waiting status
// with no worker, start time, or heartbeat, and with the archived flag set
// and a zero pin order.
func TestArchiveChatMovesPendingChatToWaiting(t *testing.T) {
	t.Parallel()

	store, pubsub := dbtestutil.NewDB(t)
	srv := newTestServer(t, store, pubsub, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	owner, modelConfig := seedChatDependencies(ctx, t, store)

	createOpts := chatd.CreateOptions{
		OwnerID:            owner.ID,
		Title:              "archive-pending",
		ModelConfigID:      modelConfig.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	created, err := srv.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// Write the pending status directly, with no worker or timing data, so
	// the archive call is the only thing that can move the chat to waiting.
	pendingChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusPending,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	})
	require.NoError(t, err)

	require.NoError(t, srv.ArchiveChat(ctx, pendingChat))

	// Read the row back and confirm both the status reset and the archive.
	stored, err := store.GetChatByID(ctx, pendingChat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusWaiting, stored.Status)
	require.False(t, stored.WorkerID.Valid)
	require.False(t, stored.StartedAt.Valid)
	require.False(t, stored.HeartbeatAt.Valid)
	require.True(t, stored.Archived)
	require.Zero(t, stored.PinOrder)
}
// TestArchiveChatInterruptsActiveProcessing verifies that archiving a chat
// while its model stream is in flight acts as a hard stop: the stream's
// request context is canceled, subscribers receive a waiting status event,
// the stored chat ends up archived and waiting with worker fields cleared,
// and a message queued before the archive stays queued instead of being
// promoted to a user message.
func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// streamStarted is closed once the fake provider has emitted its first
// chunk; streamCanceled is closed once the streaming request's context is
// done.
streamStarted := make(chan struct{})
streamCanceled := make(chan struct{})
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
// Non-streaming requests get a canned response and never block.
// NOTE(review): presumably these are title-generation calls; confirm.
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
// Buffer of one so the single chunk send below cannot block on the reader.
chunks := make(chan chattest.OpenAIChunk, 1)
go func() {
defer close(chunks)
chunks <- chattest.OpenAITextChunks("partial")[0]
// Close only if not already closed, in case the handler runs more than once.
select {
case <-streamStarted:
default:
close(streamStarted)
}
// Hold the stream open until the request context is canceled.
<-req.Context().Done()
select {
case <-streamCanceled:
default:
close(streamCanceled)
}
}()
return chattest.OpenAIResponse{StreamingChunks: chunks}
})
server := newActiveTestServer(t, db, ps)
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "archive-interrupt",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
// Wait until the database shows the chat running with a worker assigned.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
}, testutil.IntervalFast)
// Wait until the fake provider has started streaming.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
select {
case <-streamStarted:
return true
default:
return false
}
}, testutil.IntervalFast)
// Subscribe before archiving so the resulting status event is observable
// on the events channel.
_, events, cancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
defer cancel()
// Queue a second message while the chat is busy; the checks at the end
// assert that it is still queued after the archive.
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
})
require.NoError(t, err)
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)
// Note that chat is the value returned by CreateChat, not a fresh read of
// the now-running row.
err = server.ArchiveChat(ctx, chat)
require.NoError(t, err)
// The in-flight stream's request context must be canceled by the archive.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
select {
case <-streamCanceled:
return true
default:
return false
}
}, testutil.IntervalFast)
gotWaitingStatus := false
// Each poll drains whatever events are currently buffered without
// blocking, and succeeds once a waiting status event has been seen.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
for {
select {
case ev := <-events:
if ev.Type == codersdk.ChatStreamEventTypeStatus &&
ev.Status != nil &&
ev.Status.Status == codersdk.ChatStatusWaiting {
gotWaitingStatus = true
return true
}
default:
return gotWaitingStatus
}
}
}, testutil.IntervalFast)
require.True(t, gotWaitingStatus, "expected a waiting status event after archive")
// The stored row must be archived, waiting, and have no worker, start
// time, or heartbeat.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Archived &&
fromDB.Status == database.ChatStatusWaiting &&
!fromDB.WorkerID.Valid &&
!fromDB.StartedAt.Valid &&
!fromDB.HeartbeatAt.Valid
}, testutil.IntervalFast)
// The queued message must still be in the queue, unchanged.
queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, queuedMessages, 1)
require.Equal(t, queuedResult.QueuedMessage.ID, queuedMessages[0].ID)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
// Only the initial "hello" message should have the user role; a second
// one would mean the queued message was promoted.
userMessages := 0
for _, msg := range messages {
if msg.Role == database.ChatMessageRoleUser {
userMessages++
}
}
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
}
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
t.Parallel()