From c7cac9debe10039e3fdb7c9d37c2cdd1fa34f7ad Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:36:08 +0200 Subject: [PATCH] fix: persist per-turn model on chats and queued messages (#24688) Previously, `chats.last_model_config_id` was not updated when a user sent a mid-chat message with a different model, and queued messages did not store their own per-turn model, so promotion ran against whatever the chat row said at promote time. Chat watch events also did not merge `last_model_config_id` into the site's root, child, and per-chat caches, so sidebar labels stayed stale after direct sends and queued promotions. - Add nullable `chat_queued_messages.model_config_id`, backfilled from `chats.last_model_config_id`. Queued inserts round-trip the effective model id at enqueue time. - In `coderd/x/chatd`, direct sends update `chats.last_model_config_id` inside the same transaction that inserts the admitted user message. Manual promotion and auto-promotion use the queued row's stored `model_config_id`, with a fallback to `chats.last_model_config_id` for legacy NULL rows during rollout. `PromoteQueuedOptions.ModelConfigID` is now ignored. - On the site, extract `mergeWatchedChatSummary` and `mergeWatchedChatIntoCaches` in `site/src/api/queries/chats.ts` so status-change watch events merge `last_model_config_id` into the root infinite chat list, the parent-embedded child entry, and the per-chat `chatKey(chatId)` cache. `updated_at` guards against stale watch payloads clobbering newer cached state, while diff status events still merge their PR metadata because they are timestamped outside the chat row. Watch timestamps are compared as instants so variable fractional precision does not make fresh events look stale. - Queued promotion validates stored model config IDs before admission. Invalid legacy queued IDs fall back to the chat's current model config instead of dropping the queued message during auto-promotion. - Backend and frontend regression coverage added for admission, queue promotion (including FIFO across mixed models, legacy NULL fallback, and invalid queued model IDs), and chat watch cache merging. > Mux is acting on Mike's behalf. --- coderd/database/db2sdk/db2sdk.go | 9 +- coderd/database/dump.sql | 3 +- ..._chat_queued_message_model_config.down.sql | 2 + ...77_chat_queued_message_model_config.up.sql | 8 + coderd/database/models.go | 9 +- coderd/database/queries.sql.go | 24 +- coderd/database/queries/chats.sql | 8 +- coderd/exp_chats.go | 13 +- coderd/exp_chats_test.go | 323 ++++++++++++ coderd/x/chatd/chatd.go | 164 +++++- coderd/x/chatd/chatd_test.go | 490 ++++++++++++++++++ codersdk/chats.go | 9 +- site/src/api/queries/chats.test.ts | 378 ++++++++++++++ site/src/api/queries/chats.ts | 174 +++++++ site/src/api/typesGenerated.ts | 1 + site/src/pages/AgentsPage/AgentsPage.tsx | 147 +----- 16 files changed, 1580 insertions(+), 182 deletions(-) create mode 100644 coderd/database/migrations/000477_chat_queued_message_model_config.down.sql create mode 100644 coderd/database/migrations/000477_chat_queued_message_model_config.up.sql diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 49cac265d3..1db74da36f 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1517,10 +1517,11 @@ func ChatQueuedMessage(message database.ChatQueuedMessage) codersdk.ChatQueuedMe } return codersdk.ChatQueuedMessage{ - ID: message.ID, - ChatID: message.ChatID, - Content: parts, - CreatedAt: message.CreatedAt, + ID: message.ID, + ChatID: message.ChatID, + ModelConfigID: nullUUIDPtr(message.ModelConfigID), + Content: parts, + CreatedAt: message.CreatedAt, } } diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 464433656f..4bba9d9cb3 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1422,7 +1422,8 @@ CREATE TABLE chat_queued_messages ( id bigint NOT NULL, chat_id uuid NOT NULL, content jsonb NOT NULL, - created_at timestamp with time zone DEFAULT now() NOT NULL + created_at timestamp with time zone DEFAULT now() NOT NULL, + model_config_id uuid ); CREATE SEQUENCE chat_queued_messages_id_seq diff --git a/coderd/database/migrations/000477_chat_queued_message_model_config.down.sql b/coderd/database/migrations/000477_chat_queued_message_model_config.down.sql new file mode 100644 index 0000000000..aa655e7a9c --- /dev/null +++ b/coderd/database/migrations/000477_chat_queued_message_model_config.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE chat_queued_messages +DROP COLUMN model_config_id; diff --git a/coderd/database/migrations/000477_chat_queued_message_model_config.up.sql b/coderd/database/migrations/000477_chat_queued_message_model_config.up.sql new file mode 100644 index 0000000000..fb4fc16410 --- /dev/null +++ b/coderd/database/migrations/000477_chat_queued_message_model_config.up.sql @@ -0,0 +1,8 @@ +ALTER TABLE chat_queued_messages +ADD COLUMN model_config_id uuid; + +UPDATE chat_queued_messages AS cqm +SET model_config_id = chats.last_model_config_id +FROM chats +WHERE chats.id = cqm.chat_id + AND cqm.model_config_id IS NULL; diff --git a/coderd/database/models.go b/coderd/database/models.go index d5ffc05e82..e406ee9e1c 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4509,10 +4509,11 @@ type ChatProvider struct { } type ChatQueuedMessage struct { - ID int64 `db:"id" json:"id"` - ChatID uuid.UUID `db:"chat_id" json:"chat_id"` - Content json.RawMessage `db:"content" json:"content"` - CreatedAt time.Time `db:"created_at" json:"created_at"` + ID int64 `db:"id" json:"id"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + Content json.RawMessage `db:"content" json:"content"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"` } type ChatUsageLimitConfig struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f909dab25e..33ddec0052 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -6753,7 +6753,7 @@ func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]Get } const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many -SELECT id, chat_id, content, created_at FROM chat_queued_messages +SELECT id, chat_id, content, created_at, model_config_id FROM chat_queued_messages WHERE chat_id = $1 ORDER BY id ASC ` @@ -6772,6 +6772,7 @@ func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID &i.ChatID, &i.Content, &i.CreatedAt, + &i.ModelConfigID, ); err != nil { return nil, err } @@ -7642,24 +7643,30 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa } const insertChatQueuedMessage = `-- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content) -VALUES ($1, $2) -RETURNING id, chat_id, content, created_at +INSERT INTO chat_queued_messages (chat_id, content, model_config_id) +VALUES ( + $1, + $2, + $3::uuid +) +RETURNING id, chat_id, content, created_at, model_config_id ` type InsertChatQueuedMessageParams struct { - ChatID uuid.UUID `db:"chat_id" json:"chat_id"` - Content json.RawMessage `db:"content" json:"content"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + Content json.RawMessage `db:"content" json:"content"` + ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"` } func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) { - row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content) + row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content, arg.ModelConfigID) var i ChatQueuedMessage err := row.Scan( &i.ID, &i.ChatID, &i.Content, &i.CreatedAt, + &i.ModelConfigID, ) return i, err } @@ -7884,7 +7891,7 @@ WHERE id = ( ORDER BY cqm.id ASC LIMIT 1 ) -RETURNING id, chat_id, content, created_at +RETURNING id, chat_id, content, created_at, model_config_id ` func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error) { @@ -7895,6 +7902,7 @@ func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) &i.ChatID, &i.Content, &i.CreatedAt, + &i.ModelConfigID, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index e88f0abc41..09bdda356a 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -882,8 +882,12 @@ RETURNING *; -- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content) -VALUES (@chat_id, @content) +INSERT INTO chat_queued_messages (chat_id, content, model_config_id) +VALUES ( + @chat_id, + @content, + sqlc.narg('model_config_id')::uuid +) RETURNING *; -- name: GetChatQueuedMessages :many diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index fd31ff9e16..172d63fd3f 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -2531,13 +2531,18 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } + modelConfigID := uuid.Nil + if req.ModelConfigID != nil { + modelConfigID = *req.ModelConfigID + } + sendResult, sendErr := api.chatDaemon.SendMessage( ctx, chatd.SendMessageOptions{ ChatID: chatID, CreatedBy: apiKey.UserID, Content: contentBlocks, - ModelConfigID: req.ModelConfigID, + ModelConfigID: modelConfigID, BusyBehavior: busyBehavior, PlanMode: sendPlanMode, MCPServerIDs: req.MCPServerIDs, @@ -2560,6 +2565,12 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { }) return } + if xerrors.Is(sendErr, chatd.ErrInvalidModelConfigID) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid model config ID.", + }) + return + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to create chat message.", Detail: sendErr.Error(), diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 5d93be2e6a..5a7dcfe619 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -5870,6 +5870,329 @@ func TestPostChatMessages(t *testing.T) { }) } +func waitForChatWatchStatusChangeEvent( + ctx context.Context, + t *testing.T, + conn *websocket.Conn, + chatID uuid.UUID, +) codersdk.ChatWatchEvent { + t.Helper() + + for { + var payload codersdk.ChatWatchEvent + err := wsjson.Read(ctx, conn, &payload) + require.NoError(t, err) + if payload.Kind == codersdk.ChatWatchEventKindStatusChange && payload.Chat.ID == chatID { + return payload + } + } +} + +func TestSendMessageWithModelOverrideUpdatesLastModelConfigID(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-override-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "mid-chat model switch direct send", + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "switch to model b", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + }) + require.NoError(t, err) + require.False(t, resp.Queued) + require.NotNil(t, resp.Message) + require.NotNil(t, resp.Message.ModelConfigID) + require.Equal(t, modelConfigB.ID, *resp.Message.ModelConfigID) + + storedChat, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID) + + messages, err := db.GetChatMessagesByChatID(dbauthz.AsSystemRestricted(ctx), database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + require.Len(t, messages, 1) + require.True(t, messages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, messages[0].ModelConfigID.UUID) +} + +func TestSendMessageQueuesEffectiveModelConfigID(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-queued-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "mid-chat model switch queued send", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "queue this with model b", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + BusyBehavior: codersdk.ChatBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, resp.Queued) + require.NotNil(t, resp.QueuedMessage) + require.NotNil(t, resp.QueuedMessage.ModelConfigID) + require.Equal(t, modelConfigB.ID, *resp.QueuedMessage.ModelConfigID) + + queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Len(t, queuedMessages, 1) + require.True(t, queuedMessages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, queuedMessages[0].ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigA.ID, storedChat.LastModelConfigID) +} + +func TestQueuedMessageWithoutOverrideCapturesEnqueueTimeModel(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-later-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "capture queued enqueue-time model", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "queue with stored model", + }}, + BusyBehavior: codersdk.ChatBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, resp.Queued) + require.NotNil(t, resp.QueuedMessage) + require.NotNil(t, resp.QueuedMessage.ModelConfigID) + require.Equal(t, modelConfigA.ID, *resp.QueuedMessage.ModelConfigID) + + _, err = db.UpdateChatLastModelConfigByID(dbauthz.AsSystemRestricted(ctx), database.UpdateChatLastModelConfigByIDParams{ + ID: chat.ID, + LastModelConfigID: modelConfigB.ID, + }) + require.NoError(t, err) + + queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Len(t, queuedMessages, 1) + require.True(t, queuedMessages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigA.ID, queuedMessages[0].ModelConfigID.UUID) +} + +func TestSubsequentSendWithoutOverrideUsesPersistedModel(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-persisted-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigB.ID, + Title: "subsequent send uses persisted model", + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "reuse the persisted model", + }}, + }) + require.NoError(t, err) + require.False(t, resp.Queued) + require.NotNil(t, resp.Message) + require.NotNil(t, resp.Message.ModelConfigID) + require.Equal(t, modelConfigB.ID, *resp.Message.ModelConfigID) + + messages, err := db.GetChatMessagesByChatID(dbauthz.AsSystemRestricted(ctx), database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + require.Len(t, messages, 1) + require.True(t, messages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, messages[0].ModelConfigID.UUID) +} + +func TestWatchChatsStatusChangeCarriesUpdatedLastModelConfigID(t *testing.T) { + t.Parallel() + + t.Run("DirectSend", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-watch-direct-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "watch direct model switch", + }) + require.NoError(t, err) + + conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) + require.NoError(t, err) + defer conn.Close(websocket.StatusNormalClosure, "done") + + _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "watch the direct send override", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + }) + require.NoError(t, err) + + event := waitForChatWatchStatusChangeEvent(ctx, t, conn, chat.ID) + require.Equal(t, modelConfigB.ID, event.Chat.LastModelConfigID) + }) + + t.Run("QueuedPromotion", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-watch-promote-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "watch queued promotion model switch", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + queuedResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "queue the promoted model override", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + BusyBehavior: codersdk.ChatBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedResp.Queued) + require.NotNil(t, queuedResp.QueuedMessage) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) + require.NoError(t, err) + defer conn.Close(websocket.StatusNormalClosure, "done") + + promoteRes, err := client.Request( + ctx, + http.MethodPost, + fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedResp.QueuedMessage.ID), + nil, + ) + require.NoError(t, err) + defer promoteRes.Body.Close() + require.Equal(t, http.StatusOK, promoteRes.StatusCode) + + event := waitForChatWatchStatusChangeEvent(ctx, t, conn, chat.ID) + require.Equal(t, modelConfigB.ID, event.Chat.LastModelConfigID) + }) +} + func TestChatMessageWithFileReferences(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index bf6ddcd4ee..fe9c206279 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -871,6 +871,8 @@ func (c *streamStateCollector) Collect(ch chan<- prometheus.Metric) { const MaxQueueSize = 20 var ( + // ErrInvalidModelConfigID indicates the requested model config does not exist. + ErrInvalidModelConfigID = xerrors.New("invalid model config ID") // ErrMessageQueueFull indicates the per-chat queue limit was reached. ErrMessageQueueFull = xerrors.New("chat message queue is full") // ErrEditedMessageNotFound indicates the edited message does not exist @@ -950,7 +952,7 @@ type SendMessageOptions struct { ChatID uuid.UUID CreatedBy uuid.UUID Content []codersdk.ChatMessagePart - ModelConfigID *uuid.UUID + ModelConfigID uuid.UUID BusyBehavior SendMessageBusyBehavior PlanMode *database.NullChatPlanMode MCPServerIDs *[]uuid.UUID @@ -983,7 +985,6 @@ type PromoteQueuedOptions struct { ChatID uuid.UUID CreatedBy uuid.UUID QueuedMessageID int64 - ModelConfigID *uuid.UUID } // PromoteQueuedResult contains post-promotion message metadata. @@ -1217,9 +1218,14 @@ func (p *Server) SendMessage( } } - modelConfigID := lockedChat.LastModelConfigID - if opts.ModelConfigID != nil { - modelConfigID = *opts.ModelConfigID + modelConfigID, err := resolveSendMessageModelConfigID( + ctx, + tx, + lockedChat, + opts.ModelConfigID, + ) + if err != nil { + return err } // Update MCP server IDs on the chat when explicitly provided. @@ -1264,6 +1270,10 @@ func (p *Server) SendMessage( queued, err := tx.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ ChatID: opts.ChatID, Content: content.RawMessage, + ModelConfigID: uuid.NullUUID{ + UUID: modelConfigID, + Valid: modelConfigID != uuid.Nil, + }, }) if err != nil { return xerrors.Errorf("insert queued message: %w", err) @@ -1368,6 +1378,90 @@ func (p *Server) checkUsageLimit(ctx context.Context, store database.Store, owne return nil } +func chatdModelConfigLookupContext(ctx context.Context) context.Context { + //nolint:gocritic // Chat message admission needs daemon-scoped + // deployment-config reads for model config validation. + return dbauthz.AsChatd(ctx) +} + +func resolveSendMessageModelConfigID( + ctx context.Context, + store database.Store, + chat database.Chat, + requested uuid.UUID, +) (uuid.UUID, error) { + if requested == uuid.Nil { + return resolveFallbackModelConfigID(ctx, store, chat.LastModelConfigID) + } + + chatdCtx := chatdModelConfigLookupContext(ctx) + if _, err := store.GetChatModelConfigByID(chatdCtx, requested); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.Errorf( + "%w: %s", + ErrInvalidModelConfigID, + requested, + ) + } + return uuid.Nil, xerrors.Errorf( + "get requested model config %s: %w", + requested, + err, + ) + } + return requested, nil +} + +func resolveQueuedMessageModelConfigID( + ctx context.Context, + store database.Store, + chat database.Chat, + queuedModelConfigID uuid.NullUUID, +) (uuid.UUID, error) { + chatdCtx := chatdModelConfigLookupContext(ctx) + if queuedModelConfigID.Valid && queuedModelConfigID.UUID != uuid.Nil { + if _, err := store.GetChatModelConfigByID(chatdCtx, queuedModelConfigID.UUID); err == nil { + return queuedModelConfigID.UUID, nil + } else if !errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.Errorf( + "get queued model config %s: %w", + queuedModelConfigID.UUID, + err, + ) + } + } + + return resolveFallbackModelConfigID(ctx, store, chat.LastModelConfigID) +} + +func resolveFallbackModelConfigID( + ctx context.Context, + store database.Store, + modelConfigID uuid.UUID, +) (uuid.UUID, error) { + chatdCtx := chatdModelConfigLookupContext(ctx) + if modelConfigID != uuid.Nil { + if _, err := store.GetChatModelConfigByID(chatdCtx, modelConfigID); err == nil { + return modelConfigID, nil + } else if !errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.Errorf( + "get chat model config %s: %w", + modelConfigID, + err, + ) + } + } + + defaultConfig, err := store.GetDefaultChatModelConfig(chatdCtx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.New("no default chat model config is available") + } + return uuid.Nil, xerrors.Errorf("get default chat model config: %w", err) + } + return defaultConfig.ID, nil +} + // EditMessage marks the old user message as deleted, soft-deletes all // following messages, inserts a new message with the updated content, // clears queued messages, and moves the chat into pending status. @@ -1768,23 +1862,20 @@ func (p *Server) PromoteQueued( return ErrChatArchived } - modelConfigID := lockedChat.LastModelConfigID - if opts.ModelConfigID != nil { - modelConfigID = *opts.ModelConfigID - } - queuedMessages, err := tx.GetChatQueuedMessages(ctx, opts.ChatID) if err != nil { return xerrors.Errorf("get queued messages: %w", err) } var ( - targetContent json.RawMessage - found bool + targetContent json.RawMessage + targetModelConfigID uuid.NullUUID + found bool ) for _, qm := range queuedMessages { if qm.ID == opts.QueuedMessageID { targetContent = qm.Content + targetModelConfigID = qm.ModelConfigID found = true break } @@ -1793,6 +1884,16 @@ func (p *Server) PromoteQueued( return xerrors.New("queued message not found") } + effectiveModelConfigID, err := resolveQueuedMessageModelConfigID( + ctx, + tx, + lockedChat, + targetModelConfigID, + ) + if err != nil { + return err + } + err = tx.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{ ID: opts.QueuedMessageID, ChatID: opts.ChatID, @@ -1805,7 +1906,7 @@ func (p *Server) PromoteQueued( ctx, tx, lockedChat, - modelConfigID, + effectiveModelConfigID, pqtype.NullRawMessage{ RawMessage: targetContent, Valid: len(targetContent) > 0, @@ -3313,6 +3414,8 @@ func BuildSingleChatMessageInsertParams( return params } +// insertUserMessageAndSetPending inserts a user message, transitions the +// chat to pending when needed, and returns the refreshed chat row. func insertUserMessageAndSetPending( ctx context.Context, store database.Store, @@ -3338,7 +3441,16 @@ func insertUserMessageAndSetPending( message := messages[0] if lockedChat.Status == database.ChatStatusPending { - return message, lockedChat, nil + if modelConfigID == uuid.Nil || lockedChat.LastModelConfigID == modelConfigID { + return message, lockedChat, nil + } + // The InsertChatMessages CTE updates chats.last_model_config_id when + // the message's model config differs. Reload to surface that change. + updatedChat, err := store.GetChatByID(ctx, lockedChat.ID) + if err != nil { + return database.ChatMessage{}, database.Chat{}, xerrors.Errorf("get chat after model config update: %w", err) + } + return message, updatedChat, nil } updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ @@ -4752,13 +4864,31 @@ func (p *Server) tryAutoPromoteQueuedMessage( ) (*database.ChatMessage, []database.ChatQueuedMessage, bool, error) { logger := p.logger.With(slog.F("chat_id", chat.ID)) - nextQueued, err := tx.PopNextQueuedMessage(ctx, chat.ID) - if errors.Is(err, sql.ErrNoRows) { + queuedMessages, err := tx.GetChatQueuedMessages(ctx, chat.ID) + if err != nil { + return nil, nil, false, xerrors.Errorf("get queued messages: %w", err) + } + if len(queuedMessages) == 0 { return nil, nil, false, nil } + nextQueued := queuedMessages[0] + effectiveModelConfigID, err := resolveQueuedMessageModelConfigID( + ctx, + tx, + chat, + nextQueued.ModelConfigID, + ) + if err != nil { + return nil, nil, false, err + } + + poppedQueued, err := tx.PopNextQueuedMessage(ctx, chat.ID) if err != nil { return nil, nil, false, xerrors.Errorf("pop next queued message: %w", err) } + if poppedQueued.ID != nextQueued.ID { + return nil, nil, false, xerrors.New("popped queued message out of order") + } msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. ChatID: chat.ID, @@ -4770,7 +4900,7 @@ func (p *Server) tryAutoPromoteQueuedMessage( Valid: len(nextQueued.Content) > 0, }, database.ChatMessageVisibilityBoth, - chat.LastModelConfigID, + effectiveModelConfigID, chatprompt.CurrentContentVersion, ).withCreatedBy(chat.OwnerID)) msgs, err := insertChatMessageWithStore(ctx, tx, msgParams) diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 7f545974c3..beb1216315 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/workspacestats" @@ -2045,6 +2046,38 @@ func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) { require.Len(t, messages, 1) } +func TestSendMessageRejectsInvalidQueuedModelConfigID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfig := seedChatDependencies(ctx, t, db) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusPending, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfig.ID, + Title: "reject invalid queued model config", + }) + require.NoError(t, err) + + invalidModelConfigID := uuid.New() + _, err = replica.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")}, + ModelConfigID: invalidModelConfigID, + }) + require.ErrorIs(t, err, chatd.ErrInvalidModelConfigID) + + queued, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queued) +} + func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) { t.Parallel() @@ -2501,6 +2534,463 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing require.Equal(t, database.ChatMessageRoleUser, messages[3].Role) } +func TestPromoteQueuedMessageUsesQueuedModelConfigID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfigA := seedChatDependencies(ctx, t, db) + modelConfigB := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai", + "gpt-4o-mini-promote-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfigA.ID, + Title: "promote queued uses stored model", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("queued with model b")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: uuid.NullUUID{ + UUID: modelConfigB.ID, + Valid: true, + }, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID) + require.Equal(t, database.ChatStatusPending, storedChat.Status) +} + +func TestPromoteQueuedMessageReloadsChatWhenModelConfigChangesDuringPending(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfigA := seedChatDependencies(ctx, t, db) + modelConfigB := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai", + "gpt-4o-mini-promote-pending-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + + watchEvents := make(chan struct { + payload codersdk.ChatWatchEvent + err error + }, 1) + cancelWatch, err := ps.SubscribeWithErr( + coderdpubsub.ChatWatchEventChannel(user.ID), + coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) { + select { + case watchEvents <- struct { + payload codersdk.ChatWatchEvent + err error + }{payload: payload, err: err}: + default: + } + }), + ) + require.NoError(t, err) + defer cancelWatch() + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusPending, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfigA.ID, + Title: "promote queued reloads pending chat", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("queued with new model")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: uuid.NullUUID{ + UUID: modelConfigB.ID, + Valid: true, + }, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusPending, storedChat.Status) + require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID) + + select { + case event := <-watchEvents: + require.NoError(t, event.err) + require.Equal(t, codersdk.ChatWatchEventKindStatusChange, event.payload.Kind) + require.Equal(t, chat.ID, event.payload.Chat.ID) + require.Equal(t, codersdk.ChatStatusPending, event.payload.Chat.Status) + require.Equal(t, modelConfigB.ID, event.payload.Chat.LastModelConfigID) + case <-ctx.Done(): + t.Fatal("timed out waiting for status change watch event") + } +} + +func TestAutoPromoteQueuedMessagesPreservesPerTurnModelOrder(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + firstRunStarted := make(chan struct{}) + allowFirstRunFinish := make(chan struct{}) + var requestCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + switch requestCount.Add(1) { + case 1: + chunks := make(chan chattest.OpenAIChunk, 1) + go func() { + defer close(chunks) + chunks <- chattest.OpenAITextChunks("first run partial")[0] + select { + case <-firstRunStarted: + default: + close(firstRunStarted) + } + <-allowFirstRunFinish + }() + return chattest.OpenAIResponse{StreamingChunks: chunks} + case 2: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("second run done")...) + case 3: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("third run done")...) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("extra run done")...) + } + }) + + server := newActiveTestServer(t, db, ps) + user, org, modelConfigA := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + modelConfigB := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai-compat", + "gpt-4o-mini-queue-b-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + modelConfigC := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai-compat", + "gpt-4o-mini-queue-c-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "auto-promote per-turn model order", + ModelConfigID: modelConfigA.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + testutil.TryReceive(ctx, t, firstRunStarted) + + queuedB, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued b")}, + ModelConfigID: modelConfigB.ID, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedB.Queued) + + queuedC, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued c")}, + ModelConfigID: modelConfigC.ID, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedC.Queued) + + close(allowFirstRunFinish) + + require.Eventually(t, func() bool { + return requestCount.Load() >= 3 + }, testutil.WaitLong, testutil.IntervalFast) + chatd.WaitUntilIdleForTest(server) + + queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queuedMessages) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, storedChat.Status) + require.Equal(t, modelConfigC.ID, storedChat.LastModelConfigID) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var userTexts []string + var userModelConfigIDs []uuid.UUID + for _, message := range messages { + if message.Role != database.ChatMessageRoleUser { + continue + } + sdkMessage := db2sdk.ChatMessage(message) + require.Len(t, sdkMessage.Content, 1) + userTexts = append(userTexts, sdkMessage.Content[0].Text) + require.True(t, message.ModelConfigID.Valid) + userModelConfigIDs = append(userModelConfigIDs, message.ModelConfigID.UUID) + } + require.Equal(t, []string{"hello", "queued b", "queued c"}, userTexts) + require.Equal(t, []uuid.UUID{modelConfigA.ID, modelConfigB.ID, modelConfigC.ID}, userModelConfigIDs) +} + +func TestAutoPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) { + t.Parallel() + + testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{}) +} + +func TestAutoPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) { + t.Parallel() + + testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }) +} + +func testAutoPromoteQueuedMessageFallback(t *testing.T, queuedModelConfigID uuid.NullUUID) { + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + firstRunStarted := make(chan struct{}) + allowFirstRunFinish := make(chan struct{}) + var requestCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + switch requestCount.Add(1) { + case 1: + chunks := make(chan chattest.OpenAIChunk, 1) + go func() { + defer close(chunks) + chunks <- chattest.OpenAITextChunks("first run partial")[0] + select { + case <-firstRunStarted: + default: + close(firstRunStarted) + } + <-allowFirstRunFinish + }() + return chattest.OpenAIResponse{StreamingChunks: chunks} + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("fallback run done")...) + } + }) + + server := newActiveTestServer(t, db, ps) + user, org, modelConfig := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "auto-promote queued fallback", + ModelConfigID: modelConfig.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + testutil.TryReceive(ctx, t, firstRunStarted) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")}) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: queuedModelConfigID, + }) + require.NoError(t, err) + + close(allowFirstRunFinish) + + require.Eventually(t, func() bool { + return requestCount.Load() >= 2 + }, testutil.WaitLong, testutil.IntervalFast) + chatd.WaitUntilIdleForTest(server) + + queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queuedMessages) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, storedChat.Status) + require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var found bool + for _, message := range messages { + if message.Role != database.ChatMessageRoleUser { + continue + } + sdkMessage := db2sdk.ChatMessage(message) + require.Len(t, sdkMessage.Content, 1) + if sdkMessage.Content[0].Text != "legacy queued row" { + continue + } + require.True(t, message.ModelConfigID.Valid) + require.Equal(t, modelConfig.ID, message.ModelConfigID.UUID) + found = true + } + require.True(t, found) +} + +func TestPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfigA := seedChatDependencies(ctx, t, db) + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfigA.ID, + Title: "promote queued legacy fallback", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfigA.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigA.ID, storedChat.LastModelConfigID) +} + +func TestPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfig := seedChatDependencies(ctx, t, db) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfig.ID, + Title: "promote queued invalid fallback", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("invalid queued model")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfig.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID) +} + func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) { t.Parallel() diff --git a/codersdk/chats.go b/codersdk/chats.go index fe66e77a00..6751878850 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1263,10 +1263,11 @@ const ( // ChatQueuedMessage represents a queued message waiting to be processed. type ChatQueuedMessage struct { - ID int64 `json:"id"` - ChatID uuid.UUID `json:"chat_id" format:"uuid"` - Content []ChatMessagePart `json:"content"` - CreatedAt time.Time `json:"created_at" format:"date-time"` + ID int64 `json:"id"` + ChatID uuid.UUID `json:"chat_id" format:"uuid"` + ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` + Content []ChatMessagePart `json:"content"` + CreatedAt time.Time `json:"created_at" format:"date-time"` } // ChatStreamMessagePart is a streamed message part update. diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index 2501b600f1..54eb576569 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -25,6 +25,8 @@ import { infiniteChats, interruptChat, invalidateChatListQueries, + mergeWatchedChatIntoCaches, + mergeWatchedChatSummary, paginatedChatCostUsers, pinChat, promoteChatQueuedMessage, @@ -1892,6 +1894,382 @@ describe("updateChildInParentCache", () => { }); }); +describe("mergeWatchedChatSummary", () => { + it("merges fresh status updates without clobbering a newer title snapshot", () => { + const cachedChat = makeChat("chat-1", { + status: "pending", + title: "Fresh title", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "running", + title: "Stale title", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + }), + ).toMatchObject({ + status: "running", + title: "Fresh title", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); + + it("merges last_model_config_id when watched updated_at equals cached updated_at", () => { + const cachedChat = makeChat("chat-1", { + last_model_config_id: "11111111-1111-4111-8111-111111111111", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + last_model_config_id: "22222222-2222-4222-8222-222222222222", + updated_at: "2025-01-01T00:00:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + }).last_model_config_id, + ).toBe("22222222-2222-4222-8222-222222222222"); + }); + + it("compares updated_at values as instants instead of strings", () => { + const cachedChat = makeChat("chat-1", { + status: "pending", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.12Z", + }); + const watchedChat = makeChat("chat-1", { + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:00:00.1203Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + }), + ).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:00:00.1203Z", + }); + }); + + it("merges fresh title updates without clobbering a newer status snapshot", () => { + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Updated title", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "title_change", + }), + ).toMatchObject({ + status: "running", + title: "Updated title", + }); + }); + + it("merges title updates even when chat updated_at is older", () => { + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + updated_at: "2025-01-01T00:10:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Newer generated title", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "title_change", + }), + ).toMatchObject({ + status: "running", + title: "Newer generated title", + updated_at: "2025-01-01T00:10:00.000Z", + }); + }); + + it("merges fresh diff status updates without clobbering status or title", () => { + const cachedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/1", + pull_request_state: "open", + pull_request_title: "Old title", + pull_request_draft: false, + changes_requested: false, + additions: 1, + deletions: 2, + changed_files: 3, + refreshed_at: "2025-01-01T00:00:00.000Z", + stale_at: "2025-01-01T01:00:00.000Z", + }; + const watchedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/2", + pull_request_state: "merged", + pull_request_title: "New title", + pull_request_draft: false, + changes_requested: true, + additions: 4, + deletions: 5, + changed_files: 6, + refreshed_at: "2025-01-01T00:05:00.000Z", + stale_at: "2025-01-01T01:05:00.000Z", + }; + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + diff_status: cachedDiffStatus, + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Stale title", + diff_status: watchedDiffStatus, + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "diff_status_change", + }), + ).toMatchObject({ + status: "running", + title: "Fresh title", + diff_status: watchedDiffStatus, + }); + }); + + it("merges diff status updates even when chat updated_at is older", () => { + const cachedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/1", + pull_request_state: "open", + pull_request_title: "Old title", + pull_request_draft: false, + changes_requested: false, + additions: 1, + deletions: 2, + changed_files: 3, + refreshed_at: "2025-01-01T00:00:00.000Z", + stale_at: "2025-01-01T01:00:00.000Z", + }; + const watchedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/2", + pull_request_state: "open", + pull_request_title: "New title", + pull_request_draft: true, + changes_requested: true, + additions: 4, + deletions: 5, + changed_files: 6, + refreshed_at: "2025-01-01T00:10:00.000Z", + stale_at: "2025-01-01T01:10:00.000Z", + }; + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + diff_status: cachedDiffStatus, + updated_at: "2025-01-01T00:10:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Stale title", + diff_status: watchedDiffStatus, + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "diff_status_change", + }), + ).toMatchObject({ + status: "running", + title: "Fresh title", + diff_status: watchedDiffStatus, + updated_at: "2025-01-01T00:10:00.000Z", + }); + }); + + it("marks other chats unread on fresh status updates", () => { + const cachedChat = makeChat("chat-1", { + has_unread: false, + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + activeChatId: "chat-2", + }).has_unread, + ).toBe(true); + }); + + it("preserves has_unread for the active chat", () => { + const cachedChat = makeChat("chat-1", { + has_unread: false, + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + activeChatId: "chat-1", + }).has_unread, + ).toBe(false); + }); +}); + +describe("mergeWatchedChatIntoCaches", () => { + it("merges last_model_config_id into the root list cache and per-chat cache", () => { + const queryClient = createTestQueryClient(); + const chatId = "chat-1"; + const cachedChat = makeChat(chatId, { + status: "pending", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat(chatId, { + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + seedInfiniteChats(queryClient, [cachedChat]); + queryClient.setQueryData(chatKey(chatId), cachedChat); + + mergeWatchedChatIntoCaches(queryClient, watchedChat, { + eventKind: "status_change", + }); + + expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + expect( + queryClient.getQueryData(chatKey(chatId)), + ).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); + + it("merges last_model_config_id into the parent-embedded child snapshot and child cache", () => { + const queryClient = createTestQueryClient(); + const childId = "child-1"; + const cachedChild = makeChat(childId, { + parent_chat_id: "parent-1", + root_chat_id: "parent-1", + status: "pending", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const parent = makeChat("parent-1", { children: [cachedChild] }); + const watchedChild = makeChat(childId, { + parent_chat_id: "parent-1", + root_chat_id: "parent-1", + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + seedInfiniteChats(queryClient, [parent]); + queryClient.setQueryData(chatKey(childId), cachedChild); + + mergeWatchedChatIntoCaches(queryClient, watchedChild, { + eventKind: "status_change", + }); + + expect(readInfiniteChats(queryClient)?.[0].children?.[0]).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + expect( + queryClient.getQueryData(chatKey(childId)), + ).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); + + it("does not let an older watch payload clobber newer cached metadata", () => { + const queryClient = createTestQueryClient(); + const chatId = "chat-1"; + const cachedChat = makeChat(chatId, { + status: "completed", + title: "Fresh title", + last_model_config_id: "model-new", + workspace_id: "workspace-new", + build_id: "build-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + const staleWatchChat = makeChat(chatId, { + status: "running", + title: "Stale title", + last_model_config_id: "model-old", + workspace_id: "workspace-old", + build_id: "build-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + + seedInfiniteChats(queryClient, [cachedChat]); + queryClient.setQueryData(chatKey(chatId), cachedChat); + + mergeWatchedChatIntoCaches(queryClient, staleWatchChat, { + eventKind: "status_change", + }); + + expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({ + status: "completed", + title: "Fresh title", + last_model_config_id: "model-new", + workspace_id: "workspace-new", + build_id: "build-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + expect( + queryClient.getQueryData(chatKey(chatId)), + ).toMatchObject({ + status: "completed", + title: "Fresh title", + last_model_config_id: "model-new", + workspace_id: "workspace-new", + build_id: "build-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); +}); + describe("removeChildFromParentInCache", () => { it("removes the child from its parent's children array", () => { const queryClient = createTestQueryClient(); diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 5a59772971..5ab48bfbce 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -187,6 +187,180 @@ export const removeChildFromParentInCache = ( return found; }; +const parseUpdatedAtInstant = (updatedAt: string) => { + const match = updatedAt.match(/^(.*?)(?:\.(\d+))?(Z|[+-]\d\d:\d\d)$/); + if (!match) { + const epochMs = Date.parse(updatedAt); + return Number.isNaN(epochMs) ? undefined : { epochMs, fractionalNanos: 0 }; + } + + const [, timestampWithoutFraction, fractionalSeconds = "", timezone] = match; + const epochMs = Date.parse(`${timestampWithoutFraction}${timezone}`); + if (Number.isNaN(epochMs)) { + return undefined; + } + return { + epochMs, + fractionalNanos: Number(fractionalSeconds.slice(0, 9).padEnd(9, "0")), + }; +}; + +const compareUpdatedAtInstants = (a: string, b: string): number => { + const parsedA = parseUpdatedAtInstant(a); + const parsedB = parseUpdatedAtInstant(b); + if (!parsedA || !parsedB) { + return a.localeCompare(b); + } + if (parsedA.epochMs !== parsedB.epochMs) { + return parsedA.epochMs - parsedB.epochMs; + } + return parsedA.fractionalNanos - parsedB.fractionalNanos; +}; + +type MergeWatchedChatOptions = { + readonly eventKind: TypesGen.ChatWatchEventKind; + readonly activeChatId?: string; +}; + +// Shallow-compare two ChatDiffStatus objects by their meaningful +// fields, ignoring refreshed_at/stale_at which change on every poll. +const diffStatusEqual = ( + a: TypesGen.ChatDiffStatus | undefined, + b: TypesGen.ChatDiffStatus | undefined, +): boolean => { + if (a === b) { + return true; + } + if (!a || !b) { + return false; + } + return ( + a.url === b.url && + a.pull_request_state === b.pull_request_state && + a.pull_request_title === b.pull_request_title && + a.pull_request_draft === b.pull_request_draft && + a.changes_requested === b.changes_requested && + a.additions === b.additions && + a.deletions === b.deletions && + a.changed_files === b.changed_files && + a.pr_number === b.pr_number && + a.approved === b.approved && + a.commits === b.commits + ); +}; + +/** + * Merges event-scoped chat fields into a cached summary, using updated_at + * as a stale guard while still adopting the latest DB-backed model config. + */ +export const mergeWatchedChatSummary = ( + cachedChat: TypesGen.Chat, + watchedChat: TypesGen.Chat, + { eventKind, activeChatId }: MergeWatchedChatOptions, +): TypesGen.Chat => { + const isTitleEvent = eventKind === "title_change"; + const isStatusEvent = eventKind === "status_change"; + const isDiffStatusEvent = eventKind === "diff_status_change"; + const updatedAtComparison = compareUpdatedAtInstants( + cachedChat.updated_at, + watchedChat.updated_at, + ); + const isFreshEnough = updatedAtComparison <= 0; + const nextStatus = + isFreshEnough && isStatusEvent ? watchedChat.status : cachedChat.status; + // maybeGenerateChatTitle can publish a previously loaded chat snapshot, so + // apply title_change payloads even when the chat summary timestamp is older. + const nextTitle = isTitleEvent ? watchedChat.title : cachedChat.title; + // Diff status freshness is tracked outside chats.updated_at, so apply + // diff_status_change payloads even when the chat summary timestamp is older. + const nextDiffStatus = isDiffStatusEvent + ? watchedChat.diff_status + : cachedChat.diff_status; + const nextWorkspaceId = isFreshEnough + ? (watchedChat.workspace_id ?? cachedChat.workspace_id) + : cachedChat.workspace_id; + const nextBuildId = isFreshEnough + ? (watchedChat.build_id ?? cachedChat.build_id) + : cachedChat.build_id; + // All event types carry the current model config from the DB. + const nextLastModelConfigId = isFreshEnough + ? watchedChat.last_model_config_id + : cachedChat.last_model_config_id; + const nextHasUnread = + isFreshEnough && isStatusEvent && watchedChat.id !== activeChatId + ? true + : cachedChat.has_unread; + const nextUpdatedAt = + updatedAtComparison > 0 ? cachedChat.updated_at : watchedChat.updated_at; + + // Keep updated_at in the no-op guard. This gives up the old streaming + // rerender shortcut so later stale events cannot pass isFreshEnough + // against a timestamp that should already have been superseded. + if ( + nextStatus === cachedChat.status && + nextTitle === cachedChat.title && + diffStatusEqual(nextDiffStatus, cachedChat.diff_status) && + nextWorkspaceId === cachedChat.workspace_id && + nextBuildId === cachedChat.build_id && + nextLastModelConfigId === cachedChat.last_model_config_id && + nextHasUnread === cachedChat.has_unread && + nextUpdatedAt === cachedChat.updated_at + ) { + return cachedChat; + } + + return { + ...cachedChat, + status: nextStatus, + title: nextTitle, + diff_status: nextDiffStatus, + workspace_id: nextWorkspaceId, + build_id: nextBuildId, + last_model_config_id: nextLastModelConfigId, + has_unread: nextHasUnread, + updated_at: nextUpdatedAt, + }; +}; + +/** + * Applies the same event-scoped merge and stale guard across the list, + * parent-child, and per-chat caches, covering all three cache layers. + */ +export const mergeWatchedChatIntoCaches = ( + queryClient: QueryClient, + watchedChat: TypesGen.Chat, + options: MergeWatchedChatOptions, +) => { + const mergeCachedChat = (cachedChat: TypesGen.Chat) => + mergeWatchedChatSummary(cachedChat, watchedChat, options); + + updateInfiniteChatsCache(queryClient, (chats) => { + let didUpdate = false; + const nextChats = chats.map((chat) => { + if (chat.id !== watchedChat.id) { + return chat; + } + const mergedChat = mergeCachedChat(chat); + if (mergedChat !== chat) { + didUpdate = true; + } + return mergedChat; + }); + return didUpdate ? nextChats : chats; + }); + + updateChildInParentCache(queryClient, mergeCachedChat, watchedChat.id); + queryClient.setQueryData( + chatKey(watchedChat.id), + (cachedChat) => { + if (!cachedChat) { + return cachedChat; + } + return mergeCachedChat(cachedChat); + }, + ); +}; + const getNextOptimisticPinOrder = (queryClient: QueryClient): number => { let maxPinOrder = 0; const queries = queryClient.getQueriesData< diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 8886232468..dd35450810 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -2168,6 +2168,7 @@ export const ChatProviderConfigSources: ChatProviderConfigSource[] = [ export interface ChatQueuedMessage { readonly id: number; readonly chat_id: string; + readonly model_config_id?: string; readonly content: readonly ChatMessagePart[]; readonly created_at: string; } diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index 621944ef91..ba074f37c0 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -20,6 +20,7 @@ import { chatsByWorkspaceKeyPrefix, infiniteChats, invalidateChatListQueries, + mergeWatchedChatIntoCaches, pinChat, prependToInfiniteChatsCache, readInfiniteChatsCache, @@ -29,7 +30,6 @@ import { unarchiveChat, unpinChat, updateChatTitle, - updateChildInParentCache, updateInfiniteChatsCache, } from "#/api/queries/chats"; import { workspaceById } from "#/api/queries/workspaces"; @@ -55,29 +55,6 @@ import { chatDetailErrorsEqual, } from "./utils/usageLimitMessage"; -// Shallow-compare two ChatDiffStatus objects by their meaningful -// fields, ignoring refreshed_at/stale_at which change on every poll. -function diffStatusEqual( - a: TypesGen.ChatDiffStatus | undefined, - b: TypesGen.ChatDiffStatus | undefined, -): boolean { - if (a === b) return true; - if (!a || !b) return false; - return ( - a.url === b.url && - a.pull_request_state === b.pull_request_state && - a.pull_request_title === b.pull_request_title && - a.pull_request_draft === b.pull_request_draft && - a.changes_requested === b.changes_requested && - a.additions === b.additions && - a.deletions === b.deletions && - a.changed_files === b.changed_files && - a.pr_number === b.pr_number && - a.approved === b.approved && - a.commits === b.commits - ); -} - export type { AgentsOutletContext } from "./AgentsPageView"; const AgentsPage: FC = () => { @@ -557,14 +534,8 @@ const AgentsPage: FC = () => { exact: true, }); } - // Scope field updates by event kind so that - // status_change events (which may carry a stale title - // snapshot from before async title generation - // finished) don't clobber a title_change that already - // landed. - const isTitleEvent = chatEvent.kind === "title_change"; - const isStatusEvent = chatEvent.kind === "status_change"; - const isDiffStatusEvent = chatEvent.kind === "diff_status_change"; + // Merge watch payloads by event kind so stale field + // snapshots do not clobber fresher cached metadata. // Cancel in-flight list and per-chat refetches so // they cannot overwrite the cache update below with @@ -606,117 +577,11 @@ const AgentsPage: FC = () => { prependToInfiniteChatsCache(queryClient, updatedChat); } } else { - // Build a field updater shared between root and - // child cache update paths. - const applyFields = (c: TypesGen.Chat): TypesGen.Chat => { - const nextStatus = isStatusEvent ? updatedChat.status : c.status; - const nextTitle = isTitleEvent ? updatedChat.title : c.title; - const nextDiffStatus = isDiffStatusEvent - ? updatedChat.diff_status - : c.diff_status; - const nextWorkspaceId = - updatedChat.workspace_id ?? c.workspace_id; - const nextBuildId = updatedChat.build_id ?? c.build_id; - const nextUpdatedAt = - c.updated_at > updatedChat.updated_at - ? c.updated_at - : updatedChat.updated_at; - // The server's pubsub path does not compute - // has_unread (it always sends false). For - // status_change events on non-active chats, - // optimistically mark as unread since the - // assistant produced new output. - const nextHasUnread = - isStatusEvent && updatedChat.id !== activeChatIDRef.current - ? true - : c.has_unread; - if ( - nextStatus === c.status && - nextTitle === c.title && - diffStatusEqual(nextDiffStatus, c.diff_status) && - nextWorkspaceId === c.workspace_id && - nextBuildId === c.build_id && - nextHasUnread === c.has_unread - ) { - return c; - } - return { - ...c, - status: nextStatus, - title: nextTitle, - diff_status: nextDiffStatus, - workspace_id: nextWorkspaceId, - build_id: nextBuildId, - updated_at: nextUpdatedAt, - has_unread: nextHasUnread, - }; - }; - - // Try root-level update first. - updateInfiniteChatsCache(queryClient, (chats) => { - let didUpdate = false; - const nextChats = chats.map((c) => { - if (c.id !== updatedChat.id) return c; - const result = applyFields(c); - if (result !== c) didUpdate = true; - return result; - }); - return didUpdate ? nextChats : chats; + mergeWatchedChatIntoCaches(queryClient, updatedChat, { + eventKind: chatEvent.kind, + activeChatId: activeChatIDRef.current, }); - - // Also update inside parent's children array - // in case the event targets a child chat. - updateChildInParentCache(queryClient, applyFields, updatedChat.id); } - queryClient.setQueryData( - chatKey(updatedChat.id), - (previousChat) => { - if (!previousChat) { - return previousChat; - } - // Only create a new object if a field actually - // changed. Returning the same reference prevents - // react-query from notifying subscribers, avoiding - // unnecessary re-renders of AgentChatPage during - // streaming when repeated status_change events - // carry the same "running" status. - const nextStatus = isStatusEvent - ? updatedChat.status - : previousChat.status; - const nextTitle = isTitleEvent - ? updatedChat.title - : previousChat.title; - const nextDiffStatus = isDiffStatusEvent - ? updatedChat.diff_status - : previousChat.diff_status; - const nextWorkspaceId = - updatedChat.workspace_id ?? previousChat.workspace_id; - const nextBuildId = updatedChat.build_id ?? previousChat.build_id; - const nextUpdatedAt = - previousChat.updated_at > updatedChat.updated_at - ? previousChat.updated_at - : updatedChat.updated_at; - - if ( - nextStatus === previousChat.status && - nextTitle === previousChat.title && - diffStatusEqual(nextDiffStatus, previousChat.diff_status) && - nextWorkspaceId === previousChat.workspace_id && - nextBuildId === previousChat.build_id - ) { - return previousChat; - } - return { - ...previousChat, - status: nextStatus, - title: nextTitle, - diff_status: nextDiffStatus, - workspace_id: nextWorkspaceId, - build_id: nextBuildId, - updated_at: nextUpdatedAt, - }; - }, - ); }); return ws; },