mirror of
https://github.com/coder/coder.git
synced 2026-06-06 22:48:19 +00:00
fix(coderd/x/chatd): archive chat hard-interrupts active stream (#23758)
Archiving a chat now transitions pending or running chats to waiting before setting the archived flag. This publishes a status notification on `ChatStreamNotifyChannel` so `subscribeChatControl` cancels the active `processChat` context via `ErrInterrupted` — the same codepath used by the stop button. The `processChat` cleanup also skips queued-message auto-promotion when the chat is archived, so archiving behaves like a hard stop rather than interrupt-and-continue. Relates to https://github.com/coder/coder/issues/23666
This commit is contained in:
+3
-3
@@ -1641,9 +1641,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify active
|
||||
// subscribers. Fall back to direct DB for the simple
|
||||
// archive flag — no streaming state is involved.
|
||||
// Use chatDaemon when available so it can interrupt active
|
||||
// processing before broadcasting archive state. Fall back to
|
||||
// direct DB when no daemon is running.
|
||||
if archived {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
||||
|
||||
+47
-6
@@ -1244,17 +1244,57 @@ func (p *Server) EditMessage(
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ArchiveChat archives a chat and all descendants, then broadcasts a deleted event.
|
||||
// ArchiveChat archives a chat and all descendants. If the target chat is
|
||||
// pending or running, it first transitions the chat back to waiting so active
|
||||
// processing stops before the archive is broadcast.
|
||||
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
if chat.ID == uuid.Nil {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
if err := p.db.ArchiveChatByID(ctx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
statusChat := chat
|
||||
interrupted := false
|
||||
if err := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for archive: %w", err)
|
||||
}
|
||||
statusChat = lockedChat
|
||||
|
||||
// We do not call setChatWaiting here because it intentionally preserves
|
||||
// pending chats so queued-message promotion can win. Archiving is a
|
||||
// harder stop: both pending and running chats must transition to waiting.
|
||||
if lockedChat.Status == database.ChatStatusPending || lockedChat.Status == database.ChatStatusRunning {
|
||||
statusChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set chat waiting before archive: %w", err)
|
||||
}
|
||||
interrupted = true
|
||||
}
|
||||
|
||||
if err := tx.ArchiveChatByID(ctx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDeleted, nil)
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
statusChat.Archived = true
|
||||
statusChat.PinOrder = 0
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindDeleted, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3563,9 +3603,10 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
// the worker and let the processor pick it back up.
|
||||
if latestChat.Status == database.ChatStatusPending {
|
||||
status = database.ChatStatusPending
|
||||
} else if status == database.ChatStatusWaiting {
|
||||
} else if status == database.ChatStatusWaiting && !latestChat.Archived {
|
||||
// Queued messages were already admitted through SendMessage,
|
||||
// so auto-promotion only preserves FIFO order here.
|
||||
// so auto-promotion only preserves FIFO order here. Archived
|
||||
// chats skip promotion so archiving behaves like a hard stop.
|
||||
var promoteErr error
|
||||
promotedMessage, remainingQueuedMessages, shouldPublishQueueUpdate, promoteErr = p.tryAutoPromoteQueuedMessage(cleanupCtx, tx, latestChat)
|
||||
if promoteErr != nil {
|
||||
|
||||
@@ -297,6 +297,180 @@ func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestArchiveChatMovesPendingChatToWaiting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "archive-pending",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = replica.ArchiveChat(ctx, chat)
|
||||
require.NoError(t, err)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
require.False(t, fromDB.StartedAt.Valid)
|
||||
require.False(t, fromDB.HeartbeatAt.Valid)
|
||||
require.True(t, fromDB.Archived)
|
||||
require.Zero(t, fromDB.PinOrder)
|
||||
}
|
||||
|
||||
func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
streamStarted := make(chan struct{})
|
||||
streamCanceled := make(chan struct{})
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
chunks := make(chan chattest.OpenAIChunk, 1)
|
||||
go func() {
|
||||
defer close(chunks)
|
||||
chunks <- chattest.OpenAITextChunks("partial")[0]
|
||||
select {
|
||||
case <-streamStarted:
|
||||
default:
|
||||
close(streamStarted)
|
||||
}
|
||||
<-req.Context().Done()
|
||||
select {
|
||||
case <-streamCanceled:
|
||||
default:
|
||||
close(streamCanceled)
|
||||
}
|
||||
}()
|
||||
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
||||
})
|
||||
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "archive-interrupt",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
select {
|
||||
case <-streamStarted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, queuedResult.Queued)
|
||||
require.NotNil(t, queuedResult.QueuedMessage)
|
||||
|
||||
err = server.ArchiveChat(ctx, chat)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
select {
|
||||
case <-streamCanceled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
gotWaitingStatus := false
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
for {
|
||||
select {
|
||||
case ev := <-events:
|
||||
if ev.Type == codersdk.ChatStreamEventTypeStatus &&
|
||||
ev.Status != nil &&
|
||||
ev.Status.Status == codersdk.ChatStatusWaiting {
|
||||
gotWaitingStatus = true
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return gotWaitingStatus
|
||||
}
|
||||
}
|
||||
}, testutil.IntervalFast)
|
||||
require.True(t, gotWaitingStatus, "expected a waiting status event after archive")
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Archived &&
|
||||
fromDB.Status == database.ChatStatusWaiting &&
|
||||
!fromDB.WorkerID.Valid &&
|
||||
!fromDB.StartedAt.Valid &&
|
||||
!fromDB.HeartbeatAt.Valid
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queuedMessages, 1)
|
||||
require.Equal(t, queuedResult.QueuedMessage.ID, queuedMessages[0].ID)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
userMessages := 0
|
||||
for _, msg := range messages {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
userMessages++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user