diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index 31be73030d..1e18ab8322 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -988,6 +988,7 @@ func (p *Server) Subscribe( ctx context.Context, chatID uuid.UUID, requestHeader http.Header, + afterMessageID int64, ) ( []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, @@ -1013,8 +1014,14 @@ func (p *Server) Subscribe( } } - // Load initial messages from DB - messages, err := p.db.GetChatMessagesByChatID(ctx, chatID) + // Load initial messages from DB. When afterMessageID > 0 the + // caller already has messages up to that ID (e.g. from the REST + // endpoint), so we only fetch newer ones to avoid sending + // duplicate data. + messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: afterMessageID, + }) if err == nil { for _, msg := range messages { sdkMsg := db2sdk.ChatMessage(msg) @@ -1191,23 +1198,24 @@ func (p *Server) Subscribe( case notify := <-notifications: // Handle different notification types if notify.AfterMessageID > 0 { - // Read new messages from DB - messages, err := p.db.GetChatMessagesByChatID(mergedCtx, chatID) + // Read only new messages from DB. + messages, err := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: lastMessageID, + }) if err == nil { for _, msg := range messages { - if msg.ID > lastMessageID { - sdkMsg := db2sdk.ChatMessage(msg) - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &sdkMsg, - }: - } - lastMessageID = msg.ID + sdkMsg := db2sdk.ChatMessage(msg) + select { + case <-mergedCtx.Done(): + return + case mergedEvents <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessage, + ChatID: chatID, + Message: &sdkMsg, + }: } + lastMessageID = msg.ID } } } diff --git a/coderd/chatd/chatd_test.go b/coderd/chatd/chatd_test.go index cfab2da299..1c494ed75d 100644 --- a/coderd/chatd/chatd_test.go +++ b/coderd/chatd/chatd_test.go @@ -27,6 +27,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/testutil" @@ -60,7 +61,7 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) { }) require.NoError(t, err) - _, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil) + _, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil, 0) require.True(t, ok) t.Cleanup(cancel) @@ -202,7 +203,10 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) { require.NoError(t, err) require.Len(t, queued, 1) - messages, err := db.GetChatMessagesByChatID(ctx, chat.ID) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) require.NoError(t, err) require.Len(t, messages, 1) } @@ -252,7 +256,10 @@ func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) { require.NoError(t, err) require.Len(t, queued, 0) - messages, err := db.GetChatMessagesByChatID(ctx, chat.ID) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) require.NoError(t, err) require.Len(t, messages, 2) require.Equal(t, messages[len(messages)-1].ID, result.Message.ID) @@ -275,7 +282,10 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) { }) require.NoError(t, err) - initialMessages, err := db.GetChatMessagesByChatID(ctx, chat.ID) + initialMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) require.NoError(t, err) require.Len(t, initialMessages, 1) editedMessageID := initialMessages[0].ID @@ -322,7 +332,10 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) { require.Len(t, editedSDK.Content, 1) require.Equal(t, "edited", editedSDK.Content[0].Text) - messages, err := db.GetChatMessagesByChatID(ctx, chat.ID) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) require.NoError(t, err) require.Len(t, messages, 1) require.Equal(t, editedMessageID, messages[0].ID) @@ -657,7 +670,7 @@ func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) { }) require.NoError(t, err) - snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil) + snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0) require.True(t, ok) t.Cleanup(cancel) @@ -686,7 +699,7 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) { }) require.NoError(t, err) - snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil) + snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0) require.True(t, ok) t.Cleanup(cancel) @@ -708,6 +721,87 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) { } } +func TestSubscribeAfterMessageID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create a chat — this inserts one initial "user" message. + chat, err := replica.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "after-id-test", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "first"}}, + }) + require.NoError(t, err) + + // Insert two more messages so we have three total visible + // messages (the initial user message plus these two). + msg2, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Role: "assistant", + Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"second"`), Valid: true}, + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{}, + OutputTokens: sql.NullInt64{}, + TotalTokens: sql.NullInt64{}, + ReasoningTokens: sql.NullInt64{}, + CacheCreationTokens: sql.NullInt64{}, + CacheReadTokens: sql.NullInt64{}, + ContextLimit: sql.NullInt64{}, + Compressed: sql.NullBool{}, + }) + require.NoError(t, err) + + _, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Role: "user", + Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"third"`), Valid: true}, + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{}, + OutputTokens: sql.NullInt64{}, + TotalTokens: sql.NullInt64{}, + ReasoningTokens: sql.NullInt64{}, + CacheCreationTokens: sql.NullInt64{}, + CacheReadTokens: sql.NullInt64{}, + ContextLimit: sql.NullInt64{}, + Compressed: sql.NullBool{}, + }) + require.NoError(t, err) + + // Control: Subscribe with afterMessageID=0 returns ALL messages. + allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + cancelAll() + + allMessages := filterMessageEvents(allSnapshot) + require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages") + + // Subscribe with afterMessageID set to the second message's ID. + // Only the third message (inserted after msg2) should appear. + partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID) + require.True(t, ok) + cancelPartial() + + partialMessages := filterMessageEvents(partialSnapshot) + require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2") + require.Equal(t, "user", partialMessages[0].Message.Role) +} + +// filterMessageEvents returns only the Message-type events from a +// snapshot slice, which is useful for ignoring status / queue events. +func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent { + return slice.Filter(events, func(e codersdk.ChatStreamEvent) bool { + return e.Type == codersdk.ChatStreamEventTypeMessage + }) +} + func TestCreateWorkspaceTool_EndToEnd(t *testing.T) { t.Parallel() diff --git a/coderd/chatd/subagent.go b/coderd/chatd/subagent.go index ad2991fd09..9bafd360e3 100644 --- a/coderd/chatd/subagent.go +++ b/coderd/chatd/subagent.go @@ -397,7 +397,10 @@ func latestSubagentAssistantMessage( store database.Store, chatID uuid.UUID, ) (string, error) { - messages, err := store.GetChatMessagesByChatID(ctx, chatID) + messages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }) if err != nil { return "", xerrors.Errorf("get chat messages: %w", err) } diff --git a/coderd/chats.go b/coderd/chats.go index c706ae2db1..32539589c2 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -368,7 +368,10 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) { chat := httpmw.ChatParam(r) chatID := chat.ID - messages, err := api.Database.GetChatMessagesByChatID(ctx, chatID) + messages, err := api.Database.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to get chat messages.", @@ -681,7 +684,20 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { <-senderClosed }() - snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header) + var afterMessageID int64 + if v := r.URL.Query().Get("after_id"); v != "" { + var err error + afterMessageID, err = strconv.ParseInt(v, 10, 64) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid after_id parameter.", + Detail: err.Error(), + }) + return + } + } + + snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID) if !ok { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Chat streaming is not available.", diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6ba7e7710b..28878df18e 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2465,13 +2465,13 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch return msg, nil } -func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) { +func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) { // Authorize read on the parent chat. - _, err := q.GetChatByID(ctx, chatID) + _, err := q.GetChatByID(ctx, arg.ChatID) if err != nil { return nil, err } - return q.db.GetChatMessagesByChatID(ctx, chatID) + return q.db.GetChatMessagesByChatID(ctx, arg) } func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index b1114c93fc..9787144913 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -473,9 +473,10 @@ func (s *MethodTestSuite) TestChats() { s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})} + arg := database.GetChatMessagesByChatIDParams{ChatID: chat.ID, AfterID: 0} dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() - dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes() - check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs) + dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs) })) s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e7bc3562e5..e1b8bebf1c 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1007,7 +1007,7 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da return r0, r1 } -func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) { +func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID) m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds()) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 6e1476a3d0..3d767d025c 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1838,18 +1838,18 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call { } // GetChatMessagesByChatID mocks base method. -func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) { +func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, chatID) + ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, arg) ret0, _ := ret[0].([]database.ChatMessage) ret1, _ := ret[1].(error) return ret0, ret1 } // GetChatMessagesByChatID indicates an expected call of GetChatMessagesByChatID. -func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, chatID any) *gomock.Call { +func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, chatID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg) } // GetChatMessagesForPromptByChatID mocks base method. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index bc207cd6cf..03a3660466 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -214,7 +214,7 @@ type sqlcQuerier interface { GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error) GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error) - GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) + GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error) GetChatModelConfigByProviderAndModel(ctx context.Context, arg GetChatModelConfigByProviderAndModelParams) (ChatModelConfig, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3a3ef330a2..e8e37d3069 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3112,13 +3112,19 @@ FROM chat_messages WHERE chat_id = $1::uuid + AND id > $2::bigint AND visibility IN ('user', 'both') ORDER BY created_at ASC ` -func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) { - rows, err := q.db.QueryContext(ctx, getChatMessagesByChatID, chatID) +type GetChatMessagesByChatIDParams struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + AfterID int64 `db:"after_id" json:"after_id"` +} + +func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) { + rows, err := q.db.QueryContext(ctx, getChatMessagesByChatID, arg.ChatID, arg.AfterID) if err != nil { return nil, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index bf75a12e62..71ca871544 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -41,6 +41,7 @@ FROM chat_messages WHERE chat_id = @chat_id::uuid + AND id > @after_id::bigint AND visibility IN ('user', 'both') ORDER BY created_at ASC; diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 0b17e2fc2e..c1742cff45 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -140,9 +140,16 @@ export const watchWorkspace = ( export const watchChat = ( chatId: string, + afterMessageId?: number, ): OneWayWebSocket => { + const params = new URLSearchParams(); + if (afterMessageId !== undefined && afterMessageId > 0) { + params.set("after_id", afterMessageId.toString()); + } + const query = params.toString(); + const route = `/api/experimental/chats/${chatId}/stream${query ? `?${query}` : ""}`; return new OneWayWebSocket({ - apiRoute: `/api/experimental/chats/${chatId}/stream`, + apiRoute: route, }); }; diff --git a/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx b/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx index dd7eb4bb16..28306a7545 100644 --- a/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx +++ b/site/src/pages/AgentsPage/AgentDetail/ChatContext.test.tsx @@ -202,7 +202,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -283,7 +283,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -358,7 +358,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -460,7 +460,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); const streamBaseline = streamRenderCount; @@ -526,7 +526,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -601,7 +601,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -696,7 +696,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); expect(result.current.queuedMessages.map((message) => message.id)).toEqual([ queuedMessage.id, @@ -781,7 +781,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -852,7 +852,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID1); + expect(watchChat).toHaveBeenCalledWith(chatID1, 1); }); act(() => { @@ -888,7 +888,7 @@ describe("useChatStore", () => { }); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID2); + expect(watchChat).toHaveBeenCalledWith(chatID2, 10); }); // The old WebSocket was closed during effect cleanup. @@ -935,7 +935,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); act(() => { @@ -991,7 +991,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); // Build up stream state so we can observe whether it gets cleared. @@ -1093,7 +1093,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, 1); }); // Build up stream state first. @@ -1193,7 +1193,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID1); + expect(watchChat).toHaveBeenCalledWith(chatID1, 1); }); act(() => { @@ -1229,7 +1229,7 @@ describe("useChatStore", () => { }); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID2); + expect(watchChat).toHaveBeenCalledWith(chatID2, 10); }); expect(result.current.streamState).toBeNull(); @@ -1284,7 +1284,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID1); + expect(watchChat).toHaveBeenCalledWith(chatID1, 1); }); // Verify queued messages from chat-1 are present. @@ -1310,7 +1310,7 @@ describe("useChatStore", () => { // After the switch, queued messages from chat-1 should NOT be // visible — the store resets them on chatID change. await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID2); + expect(watchChat).toHaveBeenCalledWith(chatID2, undefined); }); expect(result.current.queuedMessages).toEqual([]); }); @@ -1352,7 +1352,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); // Emit a batch with message_parts followed by a status change @@ -1424,7 +1424,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); act(() => { @@ -1483,7 +1483,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); act(() => { @@ -1536,7 +1536,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); act(() => { @@ -1598,7 +1598,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); // Set retry state first. @@ -1676,7 +1676,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); act(() => { @@ -1734,7 +1734,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); act(() => { @@ -1783,7 +1783,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); // Set an error via an error stream event first. @@ -1847,7 +1847,7 @@ describe("useChatStore", () => { ); await waitFor(() => { - expect(watchChat).toHaveBeenCalledWith(chatID); + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); }); // Transition to running — should call clearChatErrorReason. diff --git a/site/src/pages/AgentsPage/AgentDetail/ChatContext.ts b/site/src/pages/AgentsPage/AgentDetail/ChatContext.ts index 10f2186ed5..9cd93d88dc 100644 --- a/site/src/pages/AgentsPage/AgentDetail/ChatContext.ts +++ b/site/src/pages/AgentsPage/AgentDetail/ChatContext.ts @@ -435,6 +435,17 @@ export const useChatStore = ( const store = storeRef.current; + // Compute the last REST-fetched message ID so the stream can + // skip messages the client already has. We use a ref so the + // socket effect can read the latest value without including + // chatMessages in its dependency array (which would cause + // unnecessary reconnections). + const lastMessageIdRef = useRef(undefined); + lastMessageIdRef.current = + chatMessages && chatMessages.length > 0 + ? chatMessages[chatMessages.length - 1].id + : undefined; + const updateSidebarChat = useCallback( (updater: (chat: TypesGen.Chat) => TypesGen.Chat) => { if (!chatID) { @@ -550,7 +561,9 @@ export const useChatStore = ( return; } - const socket = watchChat(chatID); + // Pass the last REST-fetched message ID so the stream + // only sends newer messages. + const socket = watchChat(chatID, lastMessageIdRef.current); const handleMessage = ( payload: OneWayMessageEvent, ) => {