From 6b0518d051cabd234d724e088371bf476d102d68 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Wed, 6 May 2026 19:11:56 +0300 Subject: [PATCH] fix: state-aware queued message promotion (#24819) PromoteQueued now branches on chat status: synth tool results before the user message on requires_action, deferred reorder + Waiting on running so the worker's persist+auto-promote keeps partial output. Stale heartbeat falls through to the synchronous path; GetStaleChats picks up Waiting+queue to recover post-cleanup-crash. Endpoint returns 202. Closes CODAGT-119 --- coderd/database/dbauthz/dbauthz.go | 11 + coderd/database/dbauthz/dbauthz_test.go | 7 + coderd/database/dbmetrics/querymetrics.go | 8 + coderd/database/dbmock/dbmock.go | 15 + coderd/database/querier.go | 12 +- coderd/database/queries.sql.go | 44 +- coderd/database/queries/chats.sql | 32 +- coderd/exp_chats.go | 6 +- coderd/exp_chats_test.go | 272 ++- coderd/x/chatd/chatd.go | 224 ++- coderd/x/chatd/chatd_test.go | 1505 +++++++++++++++++ coderd/x/chatd/export_test.go | 61 + site/src/api/api.ts | 5 +- .../pages/AgentsPage/AgentChatPage.test.ts | 74 + site/src/pages/AgentsPage/AgentChatPage.tsx | 95 +- .../chatStore.createStore.test.ts | 68 + .../components/ChatConversation/chatStore.ts | 83 + .../ChatConversation/useChatStore.ts | 10 +- 18 files changed, 2431 insertions(+), 101 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index d5e56e504e..eaea49fb9f 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -6057,6 +6057,17 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU return q.db.RemoveUserFromGroups(ctx, arg) } +func (q *querier) ReorderChatQueuedMessageToFront(ctx context.Context, arg database.ReorderChatQueuedMessageToFrontParams) (int64, error) { + chat, err := q.db.GetChatByID(ctx, arg.ChatID) + if err != nil { + return 0, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return 0, err + } + return q.db.ReorderChatQueuedMessageToFront(ctx, arg) +} + func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, arg database.ResolveUserChatSpendLimitParams) (database.ResolveUserChatSpendLimitRow, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil { return database.ResolveUserChatSpendLimitRow{}, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index e58d01264c..53c0ed35bd 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1042,6 +1042,13 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().PopNextQueuedMessage(gomock.Any(), chat.ID).Return(qm, nil).AnyTimes() check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(qm) })) + s.Run("ReorderChatQueuedMessageToFront", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.ReorderChatQueuedMessageToFrontParams{ChatID: chat.ID, TargetID: 123} + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().ReorderChatQueuedMessageToFront(gomock.Any(), arg).Return(int64(1), nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1)) + })) s.Run("UpdateChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatByIDParams{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index d9a9963f9e..7930faabf5 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -4344,6 +4344,14 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas return r0, r1 } +func (m queryMetricsStore) ReorderChatQueuedMessageToFront(ctx context.Context, arg database.ReorderChatQueuedMessageToFrontParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.ReorderChatQueuedMessageToFront(ctx, arg) + m.queryLatencies.WithLabelValues("ReorderChatQueuedMessageToFront").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ReorderChatQueuedMessageToFront").Inc() + return r0, r1 +} + func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID database.ResolveUserChatSpendLimitParams) (database.ResolveUserChatSpendLimitRow, error) { start := time.Now() r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index ead72d0945..e5be0fa48b 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -8233,6 +8233,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg) } +// ReorderChatQueuedMessageToFront mocks base method. +func (m *MockStore) ReorderChatQueuedMessageToFront(ctx context.Context, arg database.ReorderChatQueuedMessageToFrontParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReorderChatQueuedMessageToFront", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReorderChatQueuedMessageToFront indicates an expected call of ReorderChatQueuedMessageToFront. +func (mr *MockStoreMockRecorder) ReorderChatQueuedMessageToFront(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReorderChatQueuedMessageToFront", reflect.TypeOf((*MockStore)(nil).ReorderChatQueuedMessageToFront), ctx, arg) +} + // ResolveUserChatSpendLimit mocks base method. func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, arg database.ResolveUserChatSpendLimitParams) (database.ResolveUserChatSpendLimitRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 935453f2fc..5d025f4b25 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -591,10 +591,13 @@ type sqlcQuerier interface { GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error) GetRuntimeConfig(ctx context.Context, key string) (string, error) - // Find chats that appear stuck and need recovery. This covers: + // Find chats that appear stuck and need recovery: // 1. Running chats whose heartbeat has expired (worker crash). - // 2. Chats awaiting client action (requires_action) past the - // timeout threshold (client disappeared). + // 2. requires_action chats past the timeout threshold (client + // disappeared). + // 3. Waiting chats with a non-empty queue and stale updated_at + // (deferred-promote stranding when the worker dies before its + // post-cancel cleanup runs). GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error) @@ -1012,6 +1015,9 @@ type sqlcQuerier interface { ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) + // Mutates only created_at on the target row; ids are unchanged so + // consumers can keep tracking queued messages by id. + ReorderChatQueuedMessageToFront(ctx context.Context, arg ReorderChatQueuedMessageToFrontParams) (int64, error) // Resolves the effective spend limit for a user using the hierarchy: // 1. Individual user override (highest priority, applies globally across // all organizations since it lives on the users table) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 03fa2904a9..e6cc1a361f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -6808,7 +6808,7 @@ func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]Get const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many SELECT id, chat_id, content, created_at, model_config_id FROM chat_queued_messages WHERE chat_id = $1 -ORDER BY id ASC +ORDER BY created_at ASC, id ASC ` func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error) { @@ -7311,12 +7311,21 @@ WHERE AND heartbeat_at < $1::timestamptz) OR (status = 'requires_action'::chat_status AND updated_at < $1::timestamptz) + OR (status = 'waiting'::chat_status + AND updated_at < $1::timestamptz + AND EXISTS ( + SELECT 1 FROM chat_queued_messages cqm + WHERE cqm.chat_id = chats.id + )) ` -// Find chats that appear stuck and need recovery. This covers: +// Find chats that appear stuck and need recovery: // 1. Running chats whose heartbeat has expired (worker crash). -// 2. Chats awaiting client action (requires_action) past the -// timeout threshold (client disappeared). +// 2. requires_action chats past the timeout threshold (client +// disappeared). +// 3. Waiting chats with a non-empty queue and stale updated_at +// (deferred-promote stranding when the worker dies before its +// post-cancel cleanup runs). func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error) { rows, err := q.db.QueryContext(ctx, getStaleChats, staleThreshold) if err != nil { @@ -7946,7 +7955,7 @@ DELETE FROM chat_queued_messages WHERE id = ( SELECT cqm.id FROM chat_queued_messages cqm WHERE cqm.chat_id = $1 - ORDER BY cqm.id ASC + ORDER BY cqm.created_at ASC, cqm.id ASC LIMIT 1 ) RETURNING id, chat_id, content, created_at, model_config_id @@ -7965,6 +7974,31 @@ func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) return i, err } +const reorderChatQueuedMessageToFront = `-- name: ReorderChatQueuedMessageToFront :execrows +UPDATE chat_queued_messages AS target +SET created_at = ( + SELECT MIN(inner_cqm.created_at) - INTERVAL '1 microsecond' + FROM chat_queued_messages AS inner_cqm + WHERE inner_cqm.chat_id = $1 +) +WHERE target.id = $2 AND target.chat_id = $1 +` + +type ReorderChatQueuedMessageToFrontParams struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + TargetID int64 `db:"target_id" json:"target_id"` +} + +// Mutates only created_at on the target row; ids are unchanged so +// consumers can keep tracking queued messages by id. +func (q *sqlQuerier) ReorderChatQueuedMessageToFront(ctx context.Context, arg ReorderChatQueuedMessageToFrontParams) (int64, error) { + result, err := q.db.ExecContext(ctx, reorderChatQueuedMessageToFront, arg.ChatID, arg.TargetID) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + const resolveUserChatSpendLimit = `-- name: ResolveUserChatSpendLimit :one SELECT CASE WHEN NOT cfg.enabled THEN -1 diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 16c3b45da9..54aea614f9 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -760,10 +760,13 @@ RETURNING *; -- name: GetStaleChats :many --- Find chats that appear stuck and need recovery. This covers: +-- Find chats that appear stuck and need recovery: -- 1. Running chats whose heartbeat has expired (worker crash). --- 2. Chats awaiting client action (requires_action) past the --- timeout threshold (client disappeared). +-- 2. requires_action chats past the timeout threshold (client +-- disappeared). +-- 3. Waiting chats with a non-empty queue and stale updated_at +-- (deferred-promote stranding when the worker dies before its +-- post-cancel cleanup runs). SELECT * FROM @@ -772,7 +775,13 @@ WHERE (status = 'running'::chat_status AND heartbeat_at < @stale_threshold::timestamptz) OR (status = 'requires_action'::chat_status - AND updated_at < @stale_threshold::timestamptz); + AND updated_at < @stale_threshold::timestamptz) + OR (status = 'waiting'::chat_status + AND updated_at < @stale_threshold::timestamptz + AND EXISTS ( + SELECT 1 FROM chat_queued_messages cqm + WHERE cqm.chat_id = chats.id + )); -- name: UpdateChatHeartbeats :many -- Bumps the heartbeat timestamp for the given set of chat IDs, @@ -916,7 +925,7 @@ RETURNING *; -- name: GetChatQueuedMessages :many SELECT * FROM chat_queued_messages WHERE chat_id = @chat_id -ORDER BY id ASC; +ORDER BY created_at ASC, id ASC; -- name: DeleteChatQueuedMessage :exec DELETE FROM chat_queued_messages WHERE id = @id AND chat_id = @chat_id; @@ -929,11 +938,22 @@ DELETE FROM chat_queued_messages WHERE id = ( SELECT cqm.id FROM chat_queued_messages cqm WHERE cqm.chat_id = @chat_id - ORDER BY cqm.id ASC + ORDER BY cqm.created_at ASC, cqm.id ASC LIMIT 1 ) RETURNING *; +-- name: ReorderChatQueuedMessageToFront :execrows +-- Mutates only created_at on the target row; ids are unchanged so +-- consumers can keep tracking queued messages by id. +UPDATE chat_queued_messages AS target +SET created_at = ( + SELECT MIN(inner_cqm.created_at) - INTERVAL '1 microsecond' + FROM chat_queued_messages AS inner_cqm + WHERE inner_cqm.chat_id = @chat_id +) +WHERE target.id = @target_id AND target.chat_id = @chat_id; + -- name: GetLastChatMessageByRole :one SELECT * diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 94a80188c4..032e1518d2 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -3193,7 +3193,7 @@ func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request return } - promoteResult, txErr := api.chatDaemon.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + _, txErr := api.chatDaemon.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ ChatID: chatID, CreatedBy: apiKey.UserID, QueuedMessageID: queuedMessageID, @@ -3216,7 +3216,9 @@ func (api *API) promoteChatQueuedMessage(rw http.ResponseWriter, r *http.Request return } - httpapi.Write(ctx, rw, http.StatusOK, convertChatMessage(promoteResult.PromotedMessage)) + httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.Response{ + Message: "Queued message promotion accepted.", + }) } // markChatAsRead updates the last read message ID for a chat to the diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index d9c36fc6e8..b68b0a0d5b 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/externalauth" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" @@ -6096,7 +6097,7 @@ func TestWatchChatsStatusChangeCarriesUpdatedLastModelConfigID(t *testing.T) { ) require.NoError(t, err) defer promoteRes.Body.Close() - require.Equal(t, http.StatusOK, promoteRes.StatusCode) + require.Equal(t, http.StatusAccepted, promoteRes.StatusCode) event := waitForChatWatchStatusChangeEvent(ctx, t, conn, chat.ID) require.Equal(t, modelConfigB.ID, event.Chat.LastModelConfigID) @@ -8163,24 +8164,11 @@ func TestPromoteChatQueuedMessage(t *testing.T) { ) require.NoError(t, err) defer promoteRes.Body.Close() - require.Equal(t, http.StatusOK, promoteRes.StatusCode) + require.Equal(t, http.StatusAccepted, promoteRes.StatusCode) - var promoted codersdk.ChatMessage - err = json.NewDecoder(promoteRes.Body).Decode(&promoted) - require.NoError(t, err) - require.NotZero(t, promoted.ID) - require.Equal(t, chat.ID, promoted.ChatID) - require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role) - - foundPromotedText := false - for _, part := range promoted.Content { - if part.Type == codersdk.ChatMessagePartTypeText && - part.Text == queuedText { - foundPromotedText = true - break - } - } - require.True(t, foundPromotedText) + var resp codersdk.Response + require.NoError(t, json.NewDecoder(promoteRes.Body).Decode(&resp)) + require.NotEmpty(t, resp.Message) messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) @@ -8188,6 +8176,19 @@ func TestPromoteChatQueuedMessage(t *testing.T) { require.NotEqual(t, queuedMessage.ID, queued.ID) } + foundPromoted := false + for _, msg := range messagesResult.Messages { + if msg.Role != codersdk.ChatMessageRoleUser { + continue + } + for _, part := range msg.Content { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText { + foundPromoted = true + } + } + } + require.True(t, foundPromoted, "promoted message must appear in chat history") + queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) for _, queued := range queuedMessages { @@ -8246,23 +8247,26 @@ func TestPromoteChatQueuedMessage(t *testing.T) { ) require.NoError(t, err) defer promoteRes.Body.Close() - require.Equal(t, http.StatusOK, promoteRes.StatusCode) + require.Equal(t, http.StatusAccepted, promoteRes.StatusCode) - var promoted codersdk.ChatMessage - err = json.NewDecoder(promoteRes.Body).Decode(&promoted) + var resp codersdk.Response + require.NoError(t, json.NewDecoder(promoteRes.Body).Decode(&resp)) + require.NotEmpty(t, resp.Message) + + messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) - require.NotZero(t, promoted.ID) - require.Equal(t, chat.ID, promoted.ChatID) - require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role) - - foundPromotedText := false - for _, part := range promoted.Content { - if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText { - foundPromotedText = true - break + foundPromoted := false + for _, msg := range messagesResult.Messages { + if msg.Role != codersdk.ChatMessageRoleUser { + continue + } + for _, part := range msg.Content { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText { + foundPromoted = true + } } } - require.True(t, foundPromotedText) + require.True(t, foundPromoted, "promoted message must appear in chat history") queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) @@ -8392,6 +8396,212 @@ func TestPromoteChatQueuedMessage(t *testing.T) { require.ErrorAs(t, promoteErr, &promoteSDKErr) require.Contains(t, promoteSDKErr.Message, "archived") }) + + t.Run("WhileRequiresAction", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const dynamicToolName = "my_dynamic_tool" + dynamicTools := []mcp.Tool{{ + Name: dynamicToolName, + Description: "a test dynamic tool", + InputSchema: mcp.ToolInputSchema{Type: "object"}, + }} + dtJSON, err := json.Marshal(dynamicTools) + require.NoError(t, err) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfig.ID, + Title: "promote queued requires-action route test", + DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true}, + }) + require.NoError(t, err) + + const pendingToolCallID = "call_pending" + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: pendingToolCallID, + ToolName: dynamicToolName, + Args: json.RawMessage(`{"x":1}`), + }}) + require.NoError(t, err) + + _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{modelConfig.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Content: []string{string(assistantContent.RawMessage)}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRequiresAction, + }) + require.NoError(t, err) + + const queuedText = "queued message for requires-action promote" + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(queuedText), + }) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage( + dbauthz.AsSystemRestricted(ctx), + database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + }, + ) + require.NoError(t, err) + + promoteRes, err := client.Request( + ctx, + http.MethodPost, + fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID), + nil, + ) + require.NoError(t, err) + defer promoteRes.Body.Close() + require.Equal(t, http.StatusAccepted, promoteRes.StatusCode) + + var resp codersdk.Response + require.NoError(t, json.NewDecoder(promoteRes.Body).Decode(&resp)) + require.NotEmpty(t, resp.Message) + + messages, err := db.GetChatMessagesByChatID(dbauthz.AsSystemRestricted(ctx), database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + syntheticID int64 + promotedID int64 + ) + for _, msg := range messages { + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if msg.Role == database.ChatMessageRoleTool && + part.Type == codersdk.ChatMessagePartTypeToolResult && + part.ToolCallID == pendingToolCallID && + part.IsError { + syntheticID = msg.ID + } + if msg.Role == database.ChatMessageRoleUser && + part.Type == codersdk.ChatMessagePartTypeText && + part.Text == queuedText { + promotedID = msg.ID + } + } + } + require.NotZero(t, syntheticID, + "expected a synthetic error tool result for the pending tool call") + require.NotZero(t, promotedID, + "expected the promoted user message in chat history") + require.Less(t, syntheticID, promotedID, + "synthetic tool result must precede the promoted user message") + + queuedRemaining, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + for _, qm := range queuedRemaining { + require.NotEqual(t, queuedMessage.ID, qm.ID) + } + }) + + t.Run("WhileRunning", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfig.ID, + Title: "promote queued running route test", + }) + require.NoError(t, err) + + // Simulate an active worker by setting status to running. + // We do not start a real worker; the running-case behavior + // (reorder + set waiting + clear worker) does not depend on + // one. The deferred auto-promote is exercised by the + // chatd-package tests where a real worker is involved. + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: dbtime.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: dbtime.Now(), Valid: true}, + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("running-promote"), + }) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage( + dbauthz.AsSystemRestricted(ctx), + database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + }, + ) + require.NoError(t, err) + + promoteRes, err := client.Request( + ctx, + http.MethodPost, + fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID), + nil, + ) + require.NoError(t, err) + defer promoteRes.Body.Close() + require.Equal(t, http.StatusAccepted, promoteRes.StatusCode) + + var resp codersdk.Response + require.NoError(t, json.NewDecoder(promoteRes.Body).Decode(&resp)) + require.NotEmpty(t, resp.Message) + + after, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, after.Status, + "running-case promote must transition chat to waiting") + require.False(t, after.WorkerID.Valid, + "running-case promote must clear WorkerID") + + queuedRemaining, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Len(t, queuedRemaining, 1) + require.Equal(t, queuedMessage.ID, queuedRemaining[0].ID, + "queued message ID must stay stable across reorder") + }) } func TestChatUsageLimitOverrideRoutes(t *testing.T) { diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 49df083dd4..3978b4e462 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -1195,6 +1195,9 @@ type PromoteQueuedOptions struct { // PromoteQueuedResult contains post-promotion message metadata. type PromoteQueuedResult struct { + // PromotedMessage is the inserted user message. For a chat that + // was running at promote time, the insertion is deferred to the + // worker's auto-promote and PromotedMessage is the zero value. PromotedMessage database.ChatMessage } @@ -2042,7 +2045,10 @@ func (p *Server) DeleteQueued( return nil } -// PromoteQueued promotes a queued message into chat history and marks the chat pending. +// PromoteQueued promotes a queued message into chat history. On a +// running chat with a fresh worker heartbeat the promote is deferred +// to the worker's persist+auto-promote so partial assistant output +// is not lost; otherwise it inserts the user message synchronously. func (p *Server) PromoteQueued( ctx context.Context, opts PromoteQueuedOptions, @@ -2052,10 +2058,12 @@ func (p *Server) PromoteQueued( } var ( - result PromoteQueuedResult - promoted database.ChatMessage - updatedChat database.Chat - remainingQueue []database.ChatQueuedMessage + result PromoteQueuedResult + promoted database.ChatMessage + updatedChat database.Chat + remainingQueue []database.ChatQueuedMessage + deferred bool + syntheticResults []database.ChatMessage ) txErr := p.db.InTx(func(tx database.Store) error { @@ -2087,7 +2095,46 @@ func (p *Server) PromoteQueued( } } if !found { - return xerrors.New("queued message not found") + return xerrors.Errorf("queued message %d not found in chat %s", opts.QueuedMessageID, opts.ChatID) + } + + // Setting pending would trip persistStep's ownership guard + // and drop the worker's partial output. Set waiting and + // reorder the queued row so the worker's auto-promote picks + // it up after the persist. + heartbeatFresh := lockedChat.HeartbeatAt.Valid && + p.clock.Now().Sub(lockedChat.HeartbeatAt.Time) < p.inFlightChatStaleAfter + if lockedChat.Status == database.ChatStatusRunning && heartbeatFresh { + rowsAffected, err := tx.ReorderChatQueuedMessageToFront(ctx, database.ReorderChatQueuedMessageToFrontParams{ + ChatID: opts.ChatID, + TargetID: opts.QueuedMessageID, + }) + if err != nil { + return xerrors.Errorf("reorder queued message to front: %w", err) + } + // Defensive guard against a future non-chat-locked + // queue mutator. The found check above makes this a + // no-op on the current code path. + if rowsAffected != 1 { + return xerrors.Errorf("reorder queued message to front affected %d rows, want 1", rowsAffected) + } + updatedChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: opts.ChatID, + Status: database.ChatStatusWaiting, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: pqtype.NullRawMessage{}, + }) + if err != nil { + return xerrors.Errorf("set chat to waiting for deferred promote: %w", err) + } + remainingQueue, err = tx.GetChatQueuedMessages(ctx, opts.ChatID) + if err != nil { + return xerrors.Errorf("get remaining queue after reorder: %w", err) + } + deferred = true + return nil } effectiveModelConfigID, err := resolveQueuedMessageModelConfigID( @@ -2100,6 +2147,20 @@ func (p *Server) PromoteQueued( return err } + // Without synthetic results, the next turn would carry + // unresolved tool_call parts; the LLM API rejects this and the + // chat dead-ends in error. + if lockedChat.Status == database.ChatStatusRequiresAction { + inserted, err := insertSyntheticToolResultsTx( + ctx, tx, lockedChat, + "Tool execution interrupted by queued message promotion", + ) + if err != nil { + return xerrors.Errorf("insert synthetic tool results: %w", err) + } + syntheticResults = inserted + } + err = tx.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{ ID: opts.QueuedMessageID, ChatID: opts.ChatID, @@ -2135,6 +2196,22 @@ func (p *Server) PromoteQueued( return PromoteQueuedResult{}, txErr } + if deferred { + // Skip publishMessage and signalWake: there is no synchronous + // user message yet, and the active worker's interrupt path + // signals its own auto-promote follow-up. + p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeQueueUpdate, + QueuedMessages: db2sdk.ChatQueuedMessages(remainingQueue), + }) + p.publishChatStreamNotify(opts.ChatID, coderdpubsub.ChatStreamNotifyMessage{ + QueueUpdate: true, + }) + p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID) + p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) + return result, nil + } + p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{ Type: codersdk.ChatStreamEventTypeQueueUpdate, QueuedMessages: db2sdk.ChatQueuedMessages(remainingQueue), @@ -2142,6 +2219,11 @@ func (p *Server) PromoteQueued( p.publishChatStreamNotify(opts.ChatID, coderdpubsub.ChatStreamNotifyMessage{ QueueUpdate: true, }) + // Publish synth rows before the user message so live viewers + // see the interruption inline. + for _, msg := range syntheticResults { + p.publishMessage(opts.ChatID, msg) + } p.publishMessage(opts.ChatID, promoted) p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID) p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) @@ -2410,7 +2492,8 @@ func (p *Server) InterruptChat( if locked.Status != database.ChatStatusRequiresAction { return nil } - return insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by user") + _, err := insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by user") + return err }, nil); txErr != nil { p.logger.Error(ctx, "failed to insert synthetic tool results during interrupt", slog.F("chat_id", chat.ID), @@ -5223,6 +5306,7 @@ func (p *Server) trackWorkspaceUsage( type finishActiveChatResult struct { updatedChat database.Chat promotedMessage *database.ChatMessage + syntheticToolResults []database.ChatMessage remainingQueuedMessages []database.ChatQueuedMessage shouldPublishQueueUpdate bool } @@ -5259,6 +5343,32 @@ func (p *Server) finishActiveChat( switch { case latestChat.Status == database.ChatStatusPending: status = database.ChatStatusPending + case latestChat.Status == database.ChatStatusWaiting && status != database.ChatStatusWaiting && !latestChat.Archived: + // PromoteQueued's deferred path won the status race. + // Insert synthetic tool results before auto-promoting, + // or a RequiresAction worker outcome reintroduces the + // stops-dead bug this PR exists to fix. + inserted, synthErr := insertSyntheticToolResultsTx( + ctx, tx, latestChat, + "Tool execution interrupted by queued message promotion", + ) + if synthErr != nil { + return xerrors.Errorf("insert synthetic tool results during promote-driven cleanup: %w", synthErr) + } + result.syntheticToolResults = inserted + var promoteErr error + result.promotedMessage, result.remainingQueuedMessages, result.shouldPublishQueueUpdate, promoteErr = p.tryAutoPromoteQueuedMessage(ctx, tx, latestChat) + if promoteErr != nil { + logger.Error(ctx, "auto-promote queued message failed during promote-driven cleanup", slog.Error(promoteErr)) + return xerrors.Errorf("auto-promote queued message: %w", promoteErr) + } + if result.promotedMessage != nil { + status = database.ChatStatusPending + } else { + // Queue drained between snapshot and lock; honor + // the external Waiting. + status = database.ChatStatusWaiting + } case status == database.ChatStatusWaiting && !latestChat.Archived: // Queued messages were already admitted through SendMessage, // so auto-promotion only preserves FIFO order here. Archived @@ -5464,6 +5574,10 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { remainingQueuedMessages = finishResult.remainingQueuedMessages shouldPublishQueueUpdate = finishResult.shouldPublishQueueUpdate + // Publish synth rows before the promoted user message. + for _, msg := range finishResult.syntheticToolResults { + p.publishMessage(chat.ID, msg) + } if promotedMessage != nil { p.publishMessage(chat.ID, *promotedMessage) } @@ -8032,7 +8146,7 @@ func formatPlanPathBlock(chatPath, home string) string { } func (p *Server) recoverStaleChats(ctx context.Context) { - staleAfter := time.Now().Add(-p.inFlightChatStaleAfter) + staleAfter := p.clock.Now().Add(-p.inFlightChatStaleAfter) staleChats, err := p.db.GetStaleChats(ctx, staleAfter) if err != nil { p.logger.Error(ctx, "failed to get stale chats", slog.Error(err)) @@ -8074,6 +8188,14 @@ func (p *Server) recoverStaleChats(ctx context.Context) { slog.F("chat_id", chat.ID)) return nil } + case database.ChatStatusWaiting: + // Deferred-promote stranding: worker died before its + // post-cancel cleanup ran. Re-check freshness. + if !locked.UpdatedAt.Before(staleAfter) { + p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery", + slog.F("chat_id", chat.ID)) + return nil + } default: // Status changed since our snapshot; skip. p.logger.Debug(ctx, "chat status changed since snapshot, skipping recovery", @@ -8113,7 +8235,7 @@ func (p *Server) recoverStaleChats(ctx context.Context) { // so the LLM history remains valid if the user // retries the chat later. if locked.Status == database.ChatStatusRequiresAction { - if synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil { + if _, synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil { p.logger.Warn(ctx, "failed to insert synthetic tool results during stale recovery", slog.F("chat_id", chat.ID), slog.Error(synthErr), @@ -8123,6 +8245,25 @@ func (p *Server) recoverStaleChats(ctx context.Context) { } } + if locked.Status == database.ChatStatusWaiting { + // Close pending dynamic tool calls; otherwise the + // promoted user message would feed the LLM a turn it + // rejects. Propagate errors so the next recovery + // tick retries instead of promoting incomplete + // history. + if _, synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by queued message promotion"); synthErr != nil { + return xerrors.Errorf("insert synthetic tool results during stale recovery: %w", synthErr) + } + promoted, _, _, promoteErr := p.tryAutoPromoteQueuedMessage(ctx, tx, locked) + if promoteErr != nil { + return xerrors.Errorf("auto-promote during stale recovery: %w", promoteErr) + } + if promoted == nil { + // Empty queue means nothing to recover. + return nil + } + } + // Reset so any replica can pick it up (pending) or // the client sees the failure (error). _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ @@ -8150,37 +8291,66 @@ func (p *Server) recoverStaleChats(ctx context.Context) { } } -// insertSyntheticToolResultsTx inserts error tool-result messages for -// every pending dynamic tool call in the last assistant message. This -// keeps the LLM message history valid (every tool-call has a matching -// tool-result) when a requires_action chat times out or is interrupted. -// It operates on the provided store, which may be a transaction handle. +// insertSyntheticToolResultsTx inserts IsError tool-result messages +// for unresolved dynamic tool calls in the last assistant message, +// skipping calls already handled (e.g. by chatloop dispatching a +// name-colliding dynamic tool as a built-in). It operates on the +// provided store, which may be a transaction handle. func insertSyntheticToolResultsTx( ctx context.Context, store database.Store, chat database.Chat, reason string, -) error { +) ([]database.ChatMessage, error) { dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools) if err != nil { - return xerrors.Errorf("parse dynamic tools: %w", err) + return nil, xerrors.Errorf("parse dynamic tools: %w", err) } if len(dynamicToolNames) == 0 { - return nil + return nil, nil } - // Get the last assistant message to find pending tool calls. + // No assistant means nothing to close: a deferred promote can + // race a worker that fails before any persist, and the cleanup + // TX must still advance. lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{ ChatID: chat.ID, Role: database.ChatMessageRoleAssistant, }) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } if err != nil { - return xerrors.Errorf("get last assistant message: %w", err) + return nil, xerrors.Errorf("get last assistant message: %w", err) } parts, err := chatprompt.ParseContent(lastAssistant) if err != nil { - return xerrors.Errorf("parse assistant message: %w", err) + return nil, xerrors.Errorf("parse assistant message: %w", err) + } + + // Mirrors SubmitToolResults. + afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: lastAssistant.ID, + }) + if err != nil { + return nil, xerrors.Errorf("get messages after assistant: %w", err) + } + handledCallIDs := make(map[string]bool) + for _, msg := range afterMsgs { + if msg.Role != database.ChatMessageRoleTool { + continue + } + msgParts, err := chatprompt.ParseContent(msg) + if err != nil { + continue + } + for _, mp := range msgParts { + if mp.Type == codersdk.ChatMessagePartTypeToolResult { + handledCallIDs[mp.ToolCallID] = true + } + } } // Collect dynamic tool calls that need synthetic results. @@ -8189,6 +8359,9 @@ func insertSyntheticToolResultsTx( if part.Type != codersdk.ChatMessagePartTypeToolCall || !dynamicToolNames[part.ToolName] { continue } + if handledCallIDs[part.ToolCallID] { + continue + } resultPart := codersdk.ChatMessagePart{ Type: codersdk.ChatMessagePartTypeToolResult, ToolCallID: part.ToolCallID, @@ -8198,13 +8371,13 @@ func insertSyntheticToolResultsTx( } marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart}) if marshalErr != nil { - return xerrors.Errorf("marshal synthetic tool result: %w", marshalErr) + return nil, xerrors.Errorf("marshal synthetic tool result: %w", marshalErr) } resultContents = append(resultContents, marshaled) } if len(resultContents) == 0 { - return nil + return nil, nil } // Insert tool-result messages using the same pattern as @@ -8238,11 +8411,12 @@ func insertSyntheticToolResultsTx( params.ContentVersion[i] = chatprompt.CurrentContentVersion params.Visibility[i] = database.ChatMessageVisibilityBoth } - if _, err := store.InsertChatMessages(ctx, params); err != nil { - return xerrors.Errorf("insert synthetic tool results: %w", err) + inserted, err := store.InsertChatMessages(ctx, params) + if err != nil { + return nil, xerrors.Errorf("insert synthetic tool results: %w", err) } - return nil + return inserted, nil } // parseDynamicToolNames unmarshals the dynamic tools JSON column diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index d160bbe8d1..1f667ab1b7 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -8767,6 +8767,411 @@ func TestPromoteQueuedRejectsArchivedChat(t *testing.T) { require.ErrorIs(t, err, chatd.ErrChatArchived) } +// TestPromoteQueuedWhileRequiresAction guards against the +// stops-dead failure mode: promoting on requires_action without +// closing pending dynamic tool calls leaves the assistant turn +// with unresolved tool_call parts that the LLM API rejects. It +// also asserts the synthetic tool-result row is published to live +// SSE subscribers before the promoted user message. +func TestPromoteQueuedWhileRequiresAction(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("requires-action-promote") + } + if streamedCallCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk( + "my_dynamic_tool", + `{"input":"hello"}`, + ), + ) + } + // Second call: the resumed run after promote completes. + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Resumed after promotion.")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + server := newActiveTestServer(t, db, ps) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{"type": "string"}, + }, + Required: []string{"input"}, + }, + }}) + require.NoError(t, err) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "promote-while-requires-action", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Please call the dynamic tool."), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + var chatBeforePromote database.Chat + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatBeforePromote = got + return got.Status == database.ChatStatusRequiresAction || + got.Status == database.ChatStatusError + }, testutil.IntervalFast) + require.Equal(t, database.ChatStatusRequiresAction, chatBeforePromote.Status, + "expected requires_action, got %s (last_error=%q)", + chatBeforePromote.Status, chatLastErrorMessage(chatBeforePromote.LastError)) + + var pendingToolCallID string + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleAssistant { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" { + pendingToolCallID = part.ToolCallID + return true + } + } + } + return false + }, testutil.IntervalFast) + require.NotEmpty(t, pendingToolCallID, "expected pending dynamic tool call") + + queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedResult.Queued) + require.NotNil(t, queuedResult.QueuedMessage) + + // Subscribe before promoting to capture published events. + _, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + defer subCancel() + promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedResult.QueuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.Equal(t, database.ChatMessageRoleUser, promoteResult.PromotedMessage.Role) + + // Synthetic row must publish before the promoted user message. + var ( + syntheticPublishedAt int + userPublishedAt int + messagesSeen int + ) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + select { + case ev := <-events: + if ev.Type != codersdk.ChatStreamEventTypeMessage || ev.Message == nil { + return false + } + messagesSeen++ + switch ev.Message.Role { + case codersdk.ChatMessageRoleTool: + if syntheticPublishedAt == 0 { + syntheticPublishedAt = messagesSeen + } + case codersdk.ChatMessageRoleUser: + if ev.Message.ID == promoteResult.PromotedMessage.ID { + userPublishedAt = messagesSeen + } + } + return syntheticPublishedAt > 0 && userPublishedAt > 0 + default: + return false + } + }, testutil.IntervalFast) + require.Less(t, syntheticPublishedAt, userPublishedAt, + "synthetic tool-result must be published before the promoted user message") + + queuedAfter, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queuedAfter, "queued message should be removed after sync promotion") + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + syntheticToolResult *database.ChatMessage + promotedUserMessage *database.ChatMessage + ) + for i := range messages { + msg := messages[i] + if msg.Role == database.ChatMessageRoleTool { + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeToolResult { + continue + } + if part.ToolCallID != pendingToolCallID { + continue + } + require.True(t, part.IsError, + "synthetic tool result should have IsError=true") + syntheticToolResult = &messages[i] + } + } + if msg.ID == promoteResult.PromotedMessage.ID { + promotedUserMessage = &messages[i] + } + } + require.NotNil(t, syntheticToolResult, + "expected a synthetic error tool result for the pending tool call") + require.NotNil(t, promotedUserMessage) + require.Less(t, syntheticToolResult.ID, promotedUserMessage.ID, + "synthetic tool result must precede the promoted user message") + + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.IntervalFast) + final, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, final.Status, + "chat should resume to waiting after promotion (last_error=%q)", + chatLastErrorMessage(final.LastError)) +} + +// TestPromoteQueuedWhileRequiresActionMixedTools guards against +// duplicating already-resolved built-in tool results: synthetic +// results must be scoped to dynamic tool names only. +func TestPromoteQueuedWhileRequiresActionMixedTools(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("mixed-tools-promote") + } + if streamedCallCount.Add(1) == 1 { + builtinChunk := chattest.OpenAIToolCallChunk( + "read_file", + `{"path":"/tmp/test.txt"}`, + ) + dynamicChunk := chattest.OpenAIToolCallChunk( + "my_dynamic_tool", + `{"input":"hello world"}`, + ) + mergedChunk := builtinChunk + dynCall := dynamicChunk.Choices[0].ToolCalls[0] + dynCall.Index = 1 + mergedChunk.Choices[0].ToolCalls = append( + mergedChunk.Choices[0].ToolCalls, + dynCall, + ) + return chattest.OpenAIStreamingResponse(mergedChunk) + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Resumed after mixed-tool promotion.")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + server := newActiveTestServer(t, db, ps) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{"type": "string"}, + }, + Required: []string{"input"}, + }, + }}) + require.NoError(t, err) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "promote-while-requires-action-mixed", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Call both tools."), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + var chatBeforePromote database.Chat + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatBeforePromote = got + return got.Status == database.ChatStatusRequiresAction || + got.Status == database.ChatStatusError + }, testutil.IntervalFast) + require.Equal(t, database.ChatStatusRequiresAction, chatBeforePromote.Status, + "expected requires_action, got %s (last_error=%q)", + chatBeforePromote.Status, chatLastErrorMessage(chatBeforePromote.LastError)) + + // The built-in tool resolves before requires_action; capture + // its row ID to assert the dynamic synthetic comes after. + var ( + dynamicToolCallID string + builtinToolResultID int64 + builtinToolResultSeen bool + ) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for _, msg := range messages { + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" { + builtinToolResultID = msg.ID + builtinToolResultSeen = true + } + if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" { + dynamicToolCallID = part.ToolCallID + } + } + } + return builtinToolResultSeen && dynamicToolCallID != "" + }, testutil.IntervalFast) + require.NotEmpty(t, dynamicToolCallID) + require.NotZero(t, builtinToolResultID) + + queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedResult.Queued) + require.NotNil(t, queuedResult.QueuedMessage) + + _, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + defer subCancel() + promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedResult.QueuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.NotZero(t, promoteResult.PromotedMessage.ID, + "requires_action promotion is synchronous and returns the inserted message") + + // Only the dynamic tool's synth row publishes; the built-in's + // pre-existing result is not republished. + var ( + syntheticPublishCount int + userPublished bool + ) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + select { + case ev := <-events: + if ev.Type != codersdk.ChatStreamEventTypeMessage || ev.Message == nil { + return false + } + switch ev.Message.Role { + case codersdk.ChatMessageRoleTool: + syntheticPublishCount++ + case codersdk.ChatMessageRoleUser: + if ev.Message.ID == promoteResult.PromotedMessage.ID { + userPublished = true + } + } + return userPublished + default: + return false + } + }, testutil.IntervalFast) + require.Equal(t, 1, syntheticPublishCount, + "only the dynamic tool's synthetic result must be published; the built-in's pre-existing result must not be republished") + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + dynamicSyntheticCount int + builtinResultsForReadFile int + ) + for _, msg := range messages { + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeToolResult { + continue + } + switch part.ToolName { + case "read_file": + builtinResultsForReadFile++ + case "my_dynamic_tool": + if part.IsError && part.ToolCallID == dynamicToolCallID && msg.ID > builtinToolResultID { + dynamicSyntheticCount++ + } + } + } + } + require.Equal(t, 1, dynamicSyntheticCount, + "expected exactly one synthetic error tool result for the dynamic tool call") + require.Equal(t, 1, builtinResultsForReadFile, + "built-in tool result should not be duplicated by promotion") + + require.Greater(t, promoteResult.PromotedMessage.ID, builtinToolResultID) +} + func TestSubmitToolResultsRejectsArchivedChat(t *testing.T) { t.Parallel() @@ -9650,3 +10055,1103 @@ func seedAdvisorConfig( ) require.NoError(t, err) } + +// TestPromoteQueuedWhileRunning guards against the data-loss +// failure mode: promoting on a streaming chat must preserve +// partial assistant output by deferring the user-message insert +// to the worker's auto-promote. +func TestPromoteQueuedWhileRunning(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + streamStarted := make(chan struct{}) + streamCanceled := make(chan struct{}) + var streamCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("running-promote") + } + if streamCallCount.Add(1) > 1 { + // Subsequent calls are the resumed run; let it settle. + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("resumed after promotion")..., + ) + } + chunks := make(chan chattest.OpenAIChunk, 1) + go func() { + defer close(chunks) + chunks <- chattest.OpenAITextChunks("partial-running-output")[0] + select { + case <-streamStarted: + default: + close(streamStarted) + } + <-req.Context().Done() + select { + case <-streamCanceled: + default: + close(streamCanceled) + } + }() + return chattest.OpenAIResponse{StreamingChunks: chunks} + }) + + server := newActiveTestServer(t, db, ps) + user, org, model := seedChatDependencies(t, db) + setOpenAIProviderBaseURL(ctx, t, db, openAIURL) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + OrganizationID: org.ID, + Title: "promote-while-running", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + fromDB, dbErr := db.GetChatByID(ctx, chat.ID) + if dbErr != nil { + return false + } + return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid + }, testutil.IntervalFast) + + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + select { + case <-streamStarted: + return true + default: + return false + } + }, testutil.IntervalFast) + + queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedResult.Queued) + require.NotNil(t, queuedResult.QueuedMessage) + + promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedResult.QueuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + // Deferred promotion: no synchronous user message. + require.Zero(t, promoteResult.PromotedMessage.ID) + + // Worker observes waiting and cancels. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + select { + case <-streamCanceled: + return true + default: + return false + } + }, testutil.IntervalFast) + + // Partial assistant output is preserved (not lost as it was + // pre-fix) and precedes the promoted user message. Poll on the + // messages themselves: the status passes through Waiting + // transiently before finishActiveChat's external-Waiting case + // promotes the queued message and flips the chat to Pending. + // Both messages being persisted implies cleanup completed. + var ( + partialAssistantID int64 + promotedUserID int64 + ) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if err != nil { + return false + } + var ( + assistantID int64 + userID int64 + ) + for _, msg := range messages { + switch msg.Role { + case database.ChatMessageRoleAssistant: + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && strings.Contains(part.Text, "partial-running-output") { + assistantID = msg.ID + } + } + case database.ChatMessageRoleUser: + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && strings.Contains(part.Text, "promote me") { + userID = msg.ID + } + } + } + } + if assistantID == 0 || userID == 0 { + return false + } + partialAssistantID = assistantID + promotedUserID = userID + return true + }, testutil.IntervalFast) + require.Less(t, partialAssistantID, promotedUserID, + "promoted user message must follow the persisted partial output") +} + +// TestPromoteQueuedWhileRunningRespectsMessageOrder guards +// against losing or reshuffling sibling queued messages when one +// is promoted out-of-order. +func TestPromoteQueuedWhileRunningRespectsMessageOrder(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + streamStarted := make(chan struct{}) + var streamCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("running-promote-order") + } + if streamCallCount.Add(1) > 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("resumed")..., + ) + } + chunks := make(chan chattest.OpenAIChunk, 1) + go func() { + defer close(chunks) + chunks <- chattest.OpenAITextChunks("partial")[0] + select { + case <-streamStarted: + default: + close(streamStarted) + } + <-req.Context().Done() + }() + return chattest.OpenAIResponse{StreamingChunks: chunks} + }) + + server := newActiveTestServer(t, db, ps) + user, org, model := seedChatDependencies(t, db) + setOpenAIProviderBaseURL(ctx, t, db, openAIURL) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + OrganizationID: org.ID, + Title: "promote-while-running-order", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + fromDB, dbErr := db.GetChatByID(ctx, chat.ID) + if dbErr != nil { + return false + } + return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid + }, testutil.IntervalFast) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + select { + case <-streamStarted: + return true + default: + return false + } + }, testutil.IntervalFast) + + queueA, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("A")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.NotNil(t, queueA.QueuedMessage) + queueB, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("B")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.NotNil(t, queueB.QueuedMessage) + queueC, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("C")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.NotNil(t, queueC.QueuedMessage) + + promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queueB.QueuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.Zero(t, promoteResult.PromotedMessage.ID, + "running-case promotion is deferred to auto-promote") + + // PromoteQueued reorders to [B, A, C]. IDs are stable because + // only created_at is mutated. + queuedAfterPromote, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, queuedAfterPromote, 3) + require.Equal(t, queueB.QueuedMessage.ID, queuedAfterPromote[0].ID, + "promoted message must be first in the queue") + require.Equal(t, queueA.QueuedMessage.ID, queuedAfterPromote[1].ID, + "non-promoted messages preserve their relative order") + require.Equal(t, queueC.QueuedMessage.ID, queuedAfterPromote[2].ID, + "non-promoted messages preserve their relative order") + + // Poll for B in history rather than asserting the queue + // state, which races the worker's auto-promote pipeline. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, getErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if getErr != nil { + return false + } + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleUser { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + return false + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "B" { + return true + } + } + } + return false + }, testutil.IntervalFast, + "the promoted message B must appear in chat history") + + // A and C must end up in queue or history, not dropped. + remainingIDs := map[int64]bool{} + remainingQueue, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + for _, qm := range remainingQueue { + remainingIDs[qm.ID] = true + } + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + promotedTexts := map[string]bool{} + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleUser { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText { + promotedTexts[part.Text] = true + } + } + } + require.True(t, remainingIDs[queueA.QueuedMessage.ID] || promotedTexts["A"], + "message A must not be lost") + require.True(t, remainingIDs[queueC.QueuedMessage.ID] || promotedTexts["C"], + "message C must not be lost") +} + +// TestFinishActiveChatExternalWaitingInsertsSyntheticResults +// asserts the cleanup TX inserts synthetic tool-result rows when +// PromoteQueued's deferred path set Status=Waiting while the +// worker concluded with RequiresAction. Without it, the next +// chatloop run would feed the LLM an assistant turn with +// unresolved tool_call parts and the API would reject it. +func TestFinishActiveChatExternalWaitingInsertsSyntheticResults(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + server := newActiveTestServer(t, db, ps) + user, org, model := seedChatDependencies(t, db) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + }}) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "external-waiting-stops-dead-guard", + LastModelConfigID: model.ID, + DynamicTools: nullRawMessage(dynamicToolsJSON), + }) + require.NoError(t, err) + + // Seed a user message and an assistant message with an + // unresolved dynamic tool call. This mirrors what the worker + // would have persisted before the deferred promote arrived. + insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "user input") + + pendingCallID := "call_pending_dynamic" + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: pendingCallID, + ToolName: "my_dynamic_tool", + Args: json.RawMessage(`{}`), + }, + }) + require.NoError(t, err) + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Content: []string{string(assistantContent.RawMessage)}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + require.NoError(t, err) + + // Queue a message and put the chat in the post-promote + // Waiting state (no worker, queue at front). + queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("queued-after-promote"), + }) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent.RawMessage, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }) + require.NoError(t, err) + + // Refresh chat with current status (Waiting, no worker). + latestChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + + // Drive the cleanup path with the local-RequiresAction outcome. + updated, promoted, syntheticToolResults, finishErr := chatd.FinishActiveChatForTest( + ctx, server, latestChat, database.ChatStatusRequiresAction, "", + ) + require.NoError(t, finishErr) + require.NotNil(t, promoted, "queued message must be auto-promoted into history") + require.Equal(t, database.ChatStatusPending, updated.Status, + "chat must end Pending so the run loop picks it up") + require.Len(t, syntheticToolResults, 1, + "cleanup TX must return the inserted synthetic tool-result row so the post-TX caller can publish it") + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + assistantIdx = -1 + synthToolIdx = -1 + promotedUserIdx = -1 + ) + for i, msg := range messages { + switch msg.Role { + case database.ChatMessageRoleAssistant: + assistantIdx = i + case database.ChatMessageRoleTool: + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolResult && + part.ToolCallID == pendingCallID && part.IsError { + synthToolIdx = i + } + } + case database.ChatMessageRoleUser: + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && + part.Text == "queued-after-promote" { + promotedUserIdx = i + } + } + } + } + require.NotEqual(t, -1, assistantIdx, "assistant tool-call message present") + require.NotEqual(t, -1, synthToolIdx, + "synthetic tool result for the unresolved dynamic tool call must be inserted") + require.NotEqual(t, -1, promotedUserIdx, + "promoted queued message must be inserted as a user message") + require.Less(t, assistantIdx, synthToolIdx, + "synthetic tool result must follow the assistant message") + require.Less(t, synthToolIdx, promotedUserIdx, + "promoted user message must follow the synthetic tool result") +} + +// TestPromoteQueuedFallsThroughOnStaleHeartbeat asserts a stale +// heartbeat takes the synchronous path so the chat does not strand +// in Waiting waiting on a worker that will not return. +func TestPromoteQueuedFallsThroughOnStaleHeartbeat(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + staleAfter := 100 * time.Millisecond + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: testutil.WaitLong, + InFlightChatStaleAfter: staleAfter, + }) + t.Cleanup(func() { require.NoError(t, server.Close()) }) + + user, org, model := seedChatDependencies(t, db) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "stale-heartbeat-promote-fallthrough", + LastModelConfigID: model.ID, + }) + require.NoError(t, err) + + // Place the chat in Running with a stale heartbeat. We do not + // start the server's run loop, so no worker will ever pick this + // chat up; the test isolates the fall-through decision in + // PromoteQueued. + deadWorker := uuid.New() + staleTime := time.Now().Add(-2 * staleAfter) + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: deadWorker, Valid: true}, + StartedAt: sql.NullTime{Time: staleTime, Valid: true}, + HeartbeatAt: sql.NullTime{Time: staleTime, Valid: true}, + }) + require.NoError(t, err) + + queued, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queued.Queued) + require.NotNil(t, queued.QueuedMessage) + + result, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queued.QueuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.NotZero(t, result.PromotedMessage.ID, + "stale heartbeat must take the synchronous path and insert a user message inline") + + got, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusPending, got.Status, + "synchronous promote ends Pending") + require.False(t, got.WorkerID.Valid, + "worker_id is cleared by the synchronous promote") +} + +// TestRecoverStaleChatsRecoversWaitingWithQueue asserts a Waiting +// chat with a non-empty queue and stale updated_at gets recovered +// to Pending, closing the post-promote-stranding hole. +func TestRecoverStaleChatsRecoversWaitingWithQueue(t *testing.T) { + t.Parallel() + + db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + staleAfter := 100 * time.Millisecond + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: testutil.WaitLong, + InFlightChatStaleAfter: staleAfter, + }) + t.Cleanup(func() { require.NoError(t, server.Close()) }) + user, org, model := seedChatDependencies(t, db) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "stale-waiting-with-queue", + LastModelConfigID: model.ID, + }) + require.NoError(t, err) + + queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("queued-stranded"), + }) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent.RawMessage, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }) + require.NoError(t, err) + // Backdate updated_at directly so the chat is past the stale + // threshold without sleeping. + _, err = rawDB.ExecContext(ctx, + "UPDATE chats SET updated_at = $1 WHERE id = $2", + time.Now().Add(-time.Hour), chat.ID) + require.NoError(t, err) + + chatd.RecoverStaleChatsForTest(ctx, server) + + got, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusPending, got.Status, + "stale-recovery must promote the front-of-queue and set Pending") + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + var foundPromoted bool + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleUser { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && + part.Text == "queued-stranded" { + foundPromoted = true + } + } + } + require.True(t, foundPromoted, + "the front-of-queue message must be promoted into history") + + remaining, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, remaining, + "the queue is drained after the recovery promotes its only entry") +} + +// TestRecoverStaleChatsWaitingWithUnresolvedToolCallInsertsSyntheticResults +// asserts stale recovery closes pending dynamic tool calls before +// promoting, so the recovery path does not stop the chat dead by +// feeding the LLM unresolved tool_call parts. +func TestRecoverStaleChatsWaitingWithUnresolvedToolCallInsertsSyntheticResults(t *testing.T) { + t.Parallel() + + db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + staleAfter := 100 * time.Millisecond + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: testutil.WaitLong, + InFlightChatStaleAfter: staleAfter, + }) + t.Cleanup(func() { require.NoError(t, server.Close()) }) + + user, org, model := seedChatDependencies(t, db) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + }}) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "stale-waiting-with-unresolved-tool-call", + LastModelConfigID: model.ID, + DynamicTools: nullRawMessage(dynamicToolsJSON), + }) + require.NoError(t, err) + + insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "please call the tool") + + pendingCallID := "call_unresolved_dynamic" + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: pendingCallID, + ToolName: "my_dynamic_tool", + Args: json.RawMessage(`{}`), + }, + }) + require.NoError(t, err) + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Content: []string{string(assistantContent.RawMessage)}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + require.NoError(t, err) + + queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("queued-after-crash"), + }) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent.RawMessage, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }) + require.NoError(t, err) + + _, err = rawDB.ExecContext(ctx, + "UPDATE chats SET updated_at = $1 WHERE id = $2", + time.Now().Add(-time.Hour), chat.ID) + require.NoError(t, err) + + chatd.RecoverStaleChatsForTest(ctx, server) + + got, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusPending, got.Status) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + assistantIdx = -1 + synthIdx = -1 + promotedUserIdx = -1 + ) + for i, msg := range messages { + switch msg.Role { + case database.ChatMessageRoleAssistant: + assistantIdx = i + case database.ChatMessageRoleTool: + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolResult && + part.ToolCallID == pendingCallID && part.IsError { + synthIdx = i + } + } + case database.ChatMessageRoleUser: + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && + part.Text == "queued-after-crash" { + promotedUserIdx = i + } + } + } + } + require.NotEqual(t, -1, assistantIdx, "assistant tool-call message present") + require.NotEqual(t, -1, synthIdx, + "stale recovery must insert synthetic tool result for the unresolved dynamic tool call") + require.NotEqual(t, -1, promotedUserIdx, + "queued message must be promoted into history") + require.Less(t, assistantIdx, synthIdx) + require.Less(t, synthIdx, promotedUserIdx) +} + +// TestInsertSyntheticToolResultsTxSkipsAlreadyHandledCalls asserts +// the helper skips tool calls already handled (e.g. when a dynamic +// tool name collides with a built-in the chatloop dispatched). +// Without dedup the LLM would see two results for the same call ID. +func TestInsertSyntheticToolResultsTxSkipsAlreadyHandledCalls(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + user, org, model := seedChatDependencies(t, db) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{ + { + Name: "duplicate_call_tool", + Description: "Tool whose call already has a result.", + InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, + }, + { + Name: "still_pending_tool", + Description: "Tool whose call has no result yet.", + InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, + }, + }) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusRequiresAction, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "synth-results-dedup", + LastModelConfigID: model.ID, + DynamicTools: nullRawMessage(dynamicToolsJSON), + }) + require.NoError(t, err) + + insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "please call both tools") + + handledCallID := "call_already_handled" + pendingCallID := "call_still_pending" + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: handledCallID, + ToolName: "duplicate_call_tool", + Args: json.RawMessage(`{}`), + }, + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: pendingCallID, + ToolName: "still_pending_tool", + Args: json.RawMessage(`{}`), + }, + }) + require.NoError(t, err) + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Content: []string{string(assistantContent.RawMessage)}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + require.NoError(t, err) + + // Pre-insert a tool-result for the handled call ID. This + // simulates the chatloop having dispatched the colliding + // dynamic tool name as a built-in. + handledResultContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: handledCallID, + ToolName: "duplicate_call_tool", + Result: json.RawMessage(`"already done"`), + }, + }) + require.NoError(t, err) + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleTool}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Content: []string{string(handledResultContent.RawMessage)}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + require.NoError(t, err) + + chatRow, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + + _, err = chatd.InsertSyntheticToolResultsTxForTest( + ctx, db, chatRow, "synth reason", + ) + require.NoError(t, err) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + handledCount int + pendingCount int + syntheticForPending bool + ) + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleTool { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeToolResult { + continue + } + switch part.ToolCallID { + case handledCallID: + handledCount++ + case pendingCallID: + pendingCount++ + if part.IsError { + syntheticForPending = true + } + } + } + } + require.Equal(t, 1, handledCount, + "handled call must keep exactly one tool result") + require.Equal(t, 1, pendingCount, + "pending call must get exactly one synthetic tool result") + require.True(t, syntheticForPending, + "the new tool result for the pending call must be marked IsError") +} + +// nullRawMessage wraps raw JSON in a NullRawMessage. An empty input +// becomes the zero value (Valid=false). +func nullRawMessage(raw []byte) pqtype.NullRawMessage { + if len(raw) == 0 { + return pqtype.NullRawMessage{} + } + return pqtype.NullRawMessage{RawMessage: raw, Valid: true} +} + +// TestInsertSyntheticToolResultsTxReturnsNilWhenNoAssistantMessage +// asserts the helper short-circuits cleanly when no assistant +// message exists yet, so a deferred promote racing a worker that +// fails before any persist does not roll back the cleanup TX. +func TestInsertSyntheticToolResultsTxReturnsNilWhenNoAssistantMessage(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + user, org, model := seedChatDependencies(t, db) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, + }}) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "no-assistant-message", + LastModelConfigID: model.ID, + DynamicTools: nullRawMessage(dynamicToolsJSON), + }) + require.NoError(t, err) + + // No assistant message persisted. The helper must return nil so + // the caller's transaction can still advance. + _, err = chatd.InsertSyntheticToolResultsTxForTest( + ctx, db, chat, "no assistant", + ) + require.NoError(t, err) +} + +// TestRecoverStaleChatsWaitingPropagatesSynthError asserts stale +// recovery rolls back when synth-result insertion fails, leaving +// the chat Waiting for the next tick instead of promoting on top +// of incomplete history. +func TestRecoverStaleChatsWaitingPropagatesSynthError(t *testing.T) { + t.Parallel() + + db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + staleAfter := 100 * time.Millisecond + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: testutil.WaitLong, + InFlightChatStaleAfter: staleAfter, + }) + t.Cleanup(func() { require.NoError(t, server.Close()) }) + + user, org, model := seedChatDependencies(t, db) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, + }}) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + Title: "stale-waiting-synth-error", + LastModelConfigID: model.ID, + DynamicTools: nullRawMessage(dynamicToolsJSON), + }) + require.NoError(t, err) + + insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "user input") + + // Inject a synth-results error via an unsupported + // ContentVersion: the row is valid JSON so the insert + // succeeds, but chatprompt.ParseContent rejects it inside the + // helper. Brittle if a future migration adds a content_version + // CHECK constraint; switch to a mock store at that point. + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + ContentVersion: []int16{99}, + Content: []string{`{}`}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + require.NoError(t, err) + + queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("queued-not-promoted-on-synth-error"), + }) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent.RawMessage, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }) + require.NoError(t, err) + + _, err = rawDB.ExecContext(ctx, + "UPDATE chats SET updated_at = $1 WHERE id = $2", + time.Now().Add(-time.Hour), chat.ID) + require.NoError(t, err) + + chatd.RecoverStaleChatsForTest(ctx, server) + + got, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, got.Status, + "recovery must leave the chat in Waiting when synth-results fails so the next tick retries") + + // The queued message must still be in the queue, not promoted. + remaining, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, remaining, 1, + "queued message must not be promoted when synth-results fails") + + // No promoted user message should appear in history. + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleUser { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + require.NotEqual(t, "queued-not-promoted-on-synth-error", part.Text, + "queued message must not be promoted when synth-results fails") + } + } +} diff --git a/coderd/x/chatd/export_test.go b/coderd/x/chatd/export_test.go index 7c7177b88b..60e00038b7 100644 --- a/coderd/x/chatd/export_test.go +++ b/coderd/x/chatd/export_test.go @@ -1,5 +1,15 @@ package chatd +import ( + "context" + + "github.com/sqlc-dev/pqtype" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk" +) + // WaitUntilIdleForTest waits for background chat work tracked by the server to // finish without shutting the server down. Tests use this to assert final // database state only after asynchronous chat processing has completed. @@ -7,3 +17,54 @@ package chatd func WaitUntilIdleForTest(server *Server) { server.drainInflight() } + +// FinishActiveChatForTest exposes the unexported cleanup TX so tests +// can drive the post-run state machine deterministically. Returns the +// resulting chat, the promoted message (if any), the synthetic +// tool-result rows the cleanup TX inserted (if any), and the cleanup +// error. The lastError string is encoded into a structured payload +// the same way runChat does, so callers do not need to know about +// the structured-error wrapper. +func FinishActiveChatForTest( + ctx context.Context, + server *Server, + chat database.Chat, + status database.ChatStatus, + lastError string, +) (database.Chat, *database.ChatMessage, []database.ChatMessage, error) { + logger := server.logger.With(slog.F("chat_id", chat.ID)) + var encoded pqtype.NullRawMessage + if lastError != "" { + var err error + encoded, err = encodeChatLastErrorPayload(&codersdk.ChatError{ + Message: lastError, + }) + if err != nil { + return database.Chat{}, nil, nil, err + } + } + result, err := server.finishActiveChat(ctx, logger, chat, status, encoded) + if err != nil { + return database.Chat{}, nil, nil, err + } + return result.updatedChat, result.promotedMessage, result.syntheticToolResults, nil +} + +// RecoverStaleChatsForTest exposes the unexported stale-recovery loop +// so tests can assert the recovery state machine without waiting for +// the periodic ticker. +func RecoverStaleChatsForTest(ctx context.Context, server *Server) { + server.recoverStaleChats(ctx) +} + +// InsertSyntheticToolResultsTxForTest exposes the unexported helper +// so tests can verify the dedup path against pre-existing tool +// results. +func InsertSyntheticToolResultsTxForTest( + ctx context.Context, + store database.Store, + chat database.Chat, + reason string, +) ([]database.ChatMessage, error) { + return insertSyntheticToolResultsTx(ctx, store, chat, reason) +} diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 0e4c777e42..0ea9204654 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3207,11 +3207,10 @@ class ExperimentalApiMethods { promoteChatQueuedMessage = async ( chatId: string, queuedMessageId: number, - ): Promise => { - const response = await this.axios.post( + ): Promise => { + await this.axios.post( `/api/experimental/chats/${chatId}/queue/${queuedMessageId}/promote`, ); - return response.data; }; getChatDiffContents = async ( diff --git a/site/src/pages/AgentsPage/AgentChatPage.test.ts b/site/src/pages/AgentsPage/AgentChatPage.test.ts index e19fc7a3bf..0306703fcb 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.test.ts +++ b/site/src/pages/AgentsPage/AgentChatPage.test.ts @@ -1,6 +1,7 @@ import { act, renderHook } from "@testing-library/react"; import { createRef } from "react"; import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { ChatQueuedMessage } from "#/api/typesGenerated"; import { clearPersistedSidebarTabId, draftInputStorageKeyPrefix, @@ -8,6 +9,7 @@ import { getPersistedSidebarTabId, lastActiveSidebarTabStorageKeyPrefix, restoreOptimisticRequestSnapshot, + runPromoteQueuedMessage, savePersistedSidebarTabId, submitEditAndScroll, useConversationEditingState, @@ -181,6 +183,78 @@ describe("restoreOptimisticRequestSnapshot", () => { }); }); +describe("runPromoteQueuedMessage", () => { + const makeQueuedMessage = (id: number, text: string, chatID = "chat-1") => + ({ + id, + chat_id: chatID, + created_at: "2025-01-01T00:00:00Z", + content: [{ type: "text", text }], + }) as ChatQueuedMessage; + + it("suppresses the promoted ID and removes it optimistically", async () => { + const store = createChatStore(); + const a = makeQueuedMessage(1, "A"); + const b = makeQueuedMessage(2, "B"); + const c = makeQueuedMessage(3, "C"); + store.setQueuedMessages([a, b, c]); + store.setChatStatus("running"); + + const promote = vi.fn(async (_id: number) => undefined); + const clearChatErrorReason = vi.fn(); + const handleUsageLimitError = vi.fn(); + + await runPromoteQueuedMessage({ + id: b.id, + store, + promoteQueuedMessage: promote, + agentId: "chat-1", + clearChatErrorReason, + handleUsageLimitError, + }); + + expect(promote).toHaveBeenCalledWith(b.id); + + const snapshot = store.getSnapshot(); + expect(snapshot.queuedMessages.map((m) => m.id)).toEqual([a.id, c.id]); + expect(snapshot.suppressedQueuedMessageIDs.has(b.id)).toBe(true); + expect(snapshot.chatStatus).toBe("pending"); + }); + + it("rolls back queue and status, clears suppression, and rethrows on API error", async () => { + const store = createChatStore(); + const a = makeQueuedMessage(1, "A"); + const b = makeQueuedMessage(2, "B"); + store.setQueuedMessages([a, b]); + store.setChatStatus("waiting"); + + const apiError = new Error("boom"); + const promote = vi.fn(async (_id: number) => { + throw apiError; + }); + const clearChatErrorReason = vi.fn(); + const handleUsageLimitError = vi.fn(); + + await expect( + runPromoteQueuedMessage({ + id: b.id, + store, + promoteQueuedMessage: promote, + agentId: "chat-1", + clearChatErrorReason, + handleUsageLimitError, + }), + ).rejects.toBe(apiError); + + expect(handleUsageLimitError).toHaveBeenCalledWith(apiError); + + const snapshot = store.getSnapshot(); + expect(snapshot.queuedMessages.map((m) => m.id)).toEqual([a.id, b.id]); + expect(snapshot.chatStatus).toBe("waiting"); + expect(snapshot.suppressedQueuedMessageIDs.has(b.id)).toBe(false); + }); +}); + describe("useConversationEditingState", () => { const chatID = "chat-abc-123"; const expectedKey = `${draftInputStorageKeyPrefix}${chatID}`; diff --git a/site/src/pages/AgentsPage/AgentChatPage.tsx b/site/src/pages/AgentsPage/AgentChatPage.tsx index b1cbc45bbd..a01233f2fb 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.tsx +++ b/site/src/pages/AgentsPage/AgentChatPage.tsx @@ -191,6 +191,68 @@ export const restoreOptimisticRequestSnapshot = ( }); }; +/** + * Runs the optimistic queued-message promotion flow. + * + * The promote endpoint returns 202 Accepted with no message body, so the + * actual user message is delivered via SSE or the messages REST endpoint. + * Suppress the promoted ID so the transient reordered queue published by + * the running-case backend does not flash the message back into the + * visible queue. Roll back queue, status, and suppression on API error. + * + * @internal Exported for testing. + */ +export const runPromoteQueuedMessage = async (params: { + id: number; + store: Pick< + ChatStore, + | "batch" + | "clearStreamError" + | "clearStreamState" + | "getSnapshot" + | "setChatStatus" + | "setQueuedMessages" + | "setStreamError" + | "setStreamState" + | "suppressQueuedMessageID" + | "unsuppressQueuedMessageID" + >; + promoteQueuedMessage: (id: number) => Promise; + agentId: string | undefined; + clearChatErrorReason: (chatID: string) => void; + handleUsageLimitError: (error: unknown) => void; +}): Promise => { + const { + id, + store, + promoteQueuedMessage, + agentId, + clearChatErrorReason, + handleUsageLimitError, + } = params; + const previousSnapshot = store.getSnapshot(); + store.batch(() => { + store.suppressQueuedMessageID(id); + store.setQueuedMessages( + previousSnapshot.queuedMessages.filter((message) => message.id !== id), + ); + store.clearStreamState(); + store.clearStreamError(); + store.setChatStatus("pending"); + }); + if (agentId) { + clearChatErrorReason(agentId); + } + try { + await promoteQueuedMessage(id); + } catch (error) { + store.unsuppressQueuedMessageID(id); + restoreOptimisticRequestSnapshot(store, previousSnapshot); + handleUsageLimitError(error); + throw error; + } +}; + export async function submitEditAndScroll({ editMessage, editArgs, @@ -1139,30 +1201,15 @@ const AgentChatPage: FC = () => { } }; - const handlePromoteQueuedMessage = async (id: number) => { - const previousSnapshot = store.getSnapshot(); - store.setQueuedMessages( - previousSnapshot.queuedMessages.filter((message) => message.id !== id), - ); - store.clearStreamState(); - if (agentId) { - clearChatErrorReason(agentId); - } - store.clearStreamError(); - store.setChatStatus("pending"); - try { - const promotedMessage = await promoteQueuedMessage(id); - // Insert the promoted message into the store and cache - // immediately so it appears in the timeline without - // waiting for the WebSocket to deliver it. - store.upsertDurableMessage(promotedMessage); - upsertCacheMessages([promotedMessage]); - } catch (error) { - restoreOptimisticRequestSnapshot(store, previousSnapshot); - handleUsageLimitError(error); - throw error; - } - }; + const handlePromoteQueuedMessage = (id: number) => + runPromoteQueuedMessage({ + id, + store, + promoteQueuedMessage, + agentId, + clearChatErrorReason, + handleUsageLimitError, + }); const editing = useConversationEditingState({ chatID: agentId, diff --git a/site/src/pages/AgentsPage/components/ChatConversation/chatStore.createStore.test.ts b/site/src/pages/AgentsPage/components/ChatConversation/chatStore.createStore.test.ts index 2c06b4eb27..f812922201 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/chatStore.createStore.test.ts +++ b/site/src/pages/AgentsPage/components/ChatConversation/chatStore.createStore.test.ts @@ -424,6 +424,74 @@ describe("setQueuedMessages", () => { }); }); +// --------------------------------------------------------------------------- +// suppressQueuedMessageID / applyAuthoritativeQueuedMessages +// --------------------------------------------------------------------------- + +describe("suppressQueuedMessageID / applyAuthoritativeQueuedMessages", () => { + it("filters suppressed IDs from authoritative writes and auto-clears", () => { + const store = createChatStore(); + const a = makeQueuedMessage(1, "A"); + const b = makeQueuedMessage(2, "B"); + const c = makeQueuedMessage(3, "C"); + + store.setQueuedMessages([a, b, c]); + store.suppressQueuedMessageID(b.id); + expect(store.getSnapshot().suppressedQueuedMessageIDs.has(b.id)).toBe(true); + + // Transient reordered queue from the running-case backend + // must not surface the suppressed message. + store.applyAuthoritativeQueuedMessages([b, a, c]); + expect( + store.getSnapshot().queuedMessages.map((message) => message.id), + ).toEqual([a.id, c.id]); + expect(store.getSnapshot().suppressedQueuedMessageIDs.has(b.id)).toBe(true); + + store.applyAuthoritativeQueuedMessages([a, c]); + expect(store.getSnapshot().suppressedQueuedMessageIDs.has(b.id)).toBe( + false, + ); + expect( + store.getSnapshot().queuedMessages.map((message) => message.id), + ).toEqual([a.id, c.id]); + }); + + it("filters suppressed IDs from REST hydration via applyAuthoritativeQueuedMessages", () => { + const store = createChatStore(); + const a = makeQueuedMessage(1, "A"); + const b = makeQueuedMessage(2, "B"); + const c = makeQueuedMessage(3, "C"); + + store.suppressQueuedMessageID(b.id); + // REST hydration delivers the unfiltered queue [B, A, C]. + store.applyAuthoritativeQueuedMessages([b, a, c]); + expect( + store.getSnapshot().queuedMessages.map((message) => message.id), + ).toEqual([a.id, c.id]); + }); + + it("unsuppressQueuedMessageID removes IDs from the suppression set", () => { + const store = createChatStore(); + store.suppressQueuedMessageID(42); + expect(store.getSnapshot().suppressedQueuedMessageIDs.has(42)).toBe(true); + store.unsuppressQueuedMessageID(42); + expect(store.getSnapshot().suppressedQueuedMessageIDs.has(42)).toBe(false); + }); + + it("setQueuedMessages does not auto-clear suppression", () => { + const store = createChatStore(); + const a = makeQueuedMessage(1, "A"); + + store.suppressQueuedMessageID(99); + // setQueuedMessages is the optimistic path: it must not + // touch the suppression set, otherwise the optimistic write + // would lift suppression before the authoritative reordered + // queue arrives. + store.setQueuedMessages([a]); + expect(store.getSnapshot().suppressedQueuedMessageIDs.has(99)).toBe(true); + }); +}); + // --------------------------------------------------------------------------- // clearStreamState // --------------------------------------------------------------------------- diff --git a/site/src/pages/AgentsPage/components/ChatConversation/chatStore.ts b/site/src/pages/AgentsPage/components/ChatConversation/chatStore.ts index 972da499dd..1191b3d6ca 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/chatStore.ts +++ b/site/src/pages/AgentsPage/components/ChatConversation/chatStore.ts @@ -153,6 +153,11 @@ export type ChatStoreState = { retryState: RetryState | null; reconnectState: ReconnectState | null; queuedMessages: readonly TypesGen.ChatQueuedMessage[]; + // Hides queued IDs from the visible queue while the backend is + // in a transient state that would briefly include them. Used by + // the running-case promote, where the backend reorders the + // queued message to the front before auto-promoting it. + suppressedQueuedMessageIDs: ReadonlySet; subagentStatusOverrides: Map; }; @@ -173,6 +178,16 @@ export type ChatStore = { setQueuedMessages: ( queuedMessages: readonly TypesGen.ChatQueuedMessage[] | undefined, ) => void; + // Server-truthful queue snapshot, filtered through the + // suppression set. Use for SSE queue_update and REST hydration; + // optimistic writes go through setQueuedMessages so they don't + // lift suppression. + applyAuthoritativeQueuedMessages: ( + queuedMessages: readonly TypesGen.ChatQueuedMessage[] | undefined, + ) => void; + suppressQueuedMessageID: (id: number) => void; + unsuppressQueuedMessageID: (id: number) => void; + clearSuppressedQueuedMessageIDs: () => void; setChatStatus: (status: TypesGen.ChatStatus | null) => void; setStreamState: (streamState: StreamState | null) => void; setStreamError: (reason: ChatDetailError | null) => void; @@ -199,6 +214,7 @@ const createInitialState = (): ChatStoreState => ({ retryState: null, reconnectState: null, queuedMessages: [], + suppressedQueuedMessageIDs: new Set(), subagentStatusOverrides: new Map(), }); @@ -404,6 +420,73 @@ export const createChatStore = (): ChatStore => { return { ...current, queuedMessages: nextQueuedMessages }; }); }, + applyAuthoritativeQueuedMessages: (queuedMessages) => { + const incoming = queuedMessages ?? []; + setState((current) => { + let nextSuppressed = current.suppressedQueuedMessageIDs; + if (current.suppressedQueuedMessageIDs.size > 0) { + const incomingIDs = new Set(incoming.map((message) => message.id)); + let copy: Set | null = null; + for (const id of current.suppressedQueuedMessageIDs) { + if (!incomingIDs.has(id)) { + if (!copy) { + copy = new Set(current.suppressedQueuedMessageIDs); + } + copy.delete(id); + } + } + if (copy) { + nextSuppressed = copy; + } + } + const filtered = + nextSuppressed.size === 0 + ? incoming + : incoming.filter((message) => !nextSuppressed.has(message.id)); + const sameQueue = chatQueuedMessagesEqualByID( + current.queuedMessages, + filtered, + ); + const sameSuppressed = + nextSuppressed === current.suppressedQueuedMessageIDs; + if (sameQueue && sameSuppressed) { + return current; + } + return { + ...current, + queuedMessages: sameQueue ? current.queuedMessages : filtered, + suppressedQueuedMessageIDs: nextSuppressed, + }; + }); + }, + suppressQueuedMessageID: (id) => { + setState((current) => { + if (current.suppressedQueuedMessageIDs.has(id)) { + return current; + } + const next = new Set(current.suppressedQueuedMessageIDs); + next.add(id); + return { ...current, suppressedQueuedMessageIDs: next }; + }); + }, + unsuppressQueuedMessageID: (id) => { + setState((current) => { + if (!current.suppressedQueuedMessageIDs.has(id)) { + return current; + } + const next = new Set(current.suppressedQueuedMessageIDs); + next.delete(id); + return { ...current, suppressedQueuedMessageIDs: next }; + }); + }, + clearSuppressedQueuedMessageIDs: () => { + setState((current) => { + if (current.suppressedQueuedMessageIDs.size === 0) { + return current; + } + return { ...current, suppressedQueuedMessageIDs: new Set() }; + }); + }, setChatStatus: (status) => { if (state.chatStatus === status) { return; diff --git a/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts b/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts index 79d7de28b8..b1fd4be1c4 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts +++ b/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts @@ -237,6 +237,10 @@ export const useChatStore = ( wsQueueUpdateReceivedRef.current = false; wsStatusReceivedRef.current = false; store.setQueuedMessages([]); + // Suppression entries are scoped to the current chat; clear + // them on chat change so a stale promote suppression doesn't + // hide queued messages in another chat. + store.clearSuppressedQueuedMessageIDs(); if (!chatID) { return; } @@ -258,7 +262,7 @@ export const useChatStore = ( return; } queuedMessagesHydratedChatIDRef.current = chatID; - store.setQueuedMessages(chatQueuedMessages); + store.applyAuthoritativeQueuedMessages(chatQueuedMessages); }, [chatMessagesData, chatID, chatQueuedMessages, store]); useEffect(() => { @@ -473,7 +477,9 @@ export const useChatStore = ( continue; } wsQueueUpdateReceivedRef.current = true; - store.setQueuedMessages(streamEvent.queued_messages); + store.applyAuthoritativeQueuedMessages( + streamEvent.queued_messages, + ); updateChatQueuedMessages(streamEvent.queued_messages); continue; case "status": {