fix: use cursor-based query for chat stream notifications (#22510)

## Problem

The pubsub notification handler in `chatd` re-fetched **all** messages
from the DB on every new message notification, then filtered in Go with
`msg.ID > lastMessageID`. This grows linearly with conversation length —
every new message triggers a full table scan of that chat's history.

The `AfterMessageID` field in the pubsub notification payload was
clearly designed for cursor-based fetching, but no matching query
existed.

## Fix

- Add `GetChatMessagesByChatIDAfter` SQL query with `WHERE id >
@after_id`, so the database does the filtering instead of Go.
- Use it in the pubsub notification handler in `chatd.go`, passing
`lastMessageID` as the cursor.
- Implement the dbauthz wrapper (was a `panic("not implemented")` stub
from codegen) with the same read-check-on-parent-chat pattern as
adjacent methods.
- Add dbauthz test coverage for the new method.

**Not changed:** The initial snapshot in `Subscribe()` still loads all
messages — that's correct, since a newly-connecting client needs the
full conversation state. The waste was only in the ongoing notification
path.
This commit is contained in:
Kyle Carberry
2026-03-02 16:31:04 -05:00
committed by GitHub
parent e3c5d734ba
commit 5eebd3829f
14 changed files with 216 additions and 67 deletions
+24 -16
View File
@@ -988,6 +988,7 @@ func (p *Server) Subscribe(
ctx context.Context,
chatID uuid.UUID,
requestHeader http.Header,
afterMessageID int64,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
@@ -1013,8 +1014,14 @@ func (p *Server) Subscribe(
}
}
// Load initial messages from DB
messages, err := p.db.GetChatMessagesByChatID(ctx, chatID)
// Load initial messages from DB. When afterMessageID > 0 the
// caller already has messages up to that ID (e.g. from the REST
// endpoint), so we only fetch newer ones to avoid sending
// duplicate data.
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: afterMessageID,
})
if err == nil {
for _, msg := range messages {
sdkMsg := db2sdk.ChatMessage(msg)
@@ -1191,23 +1198,24 @@ func (p *Server) Subscribe(
case notify := <-notifications:
// Handle different notification types
if notify.AfterMessageID > 0 {
// Read new messages from DB
messages, err := p.db.GetChatMessagesByChatID(mergedCtx, chatID)
// Read only new messages from DB.
messages, err := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: lastMessageID,
})
if err == nil {
for _, msg := range messages {
if msg.ID > lastMessageID {
sdkMsg := db2sdk.ChatMessage(msg)
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &sdkMsg,
}:
}
lastMessageID = msg.ID
sdkMsg := db2sdk.ChatMessage(msg)
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &sdkMsg,
}:
}
lastMessageID = msg.ID
}
}
}
+101 -7
View File
@@ -27,6 +27,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/testutil"
@@ -60,7 +61,7 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
})
require.NoError(t, err)
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil)
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
@@ -202,7 +203,10 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
require.NoError(t, err)
require.Len(t, queued, 1)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 1)
}
@@ -252,7 +256,10 @@ func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
require.NoError(t, err)
require.Len(t, queued, 0)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 2)
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
@@ -275,7 +282,10 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
})
require.NoError(t, err)
initialMessages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
initialMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, initialMessages, 1)
editedMessageID := initialMessages[0].ID
@@ -322,7 +332,10 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
require.Len(t, editedSDK.Content, 1)
require.Equal(t, "edited", editedSDK.Content[0].Text)
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 1)
require.Equal(t, editedMessageID, messages[0].ID)
@@ -657,7 +670,7 @@ func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
})
require.NoError(t, err)
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil)
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
@@ -686,7 +699,7 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
})
require.NoError(t, err)
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil)
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
@@ -708,6 +721,87 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
}
}
func TestSubscribeAfterMessageID(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
replica := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Create a chat — this inserts one initial "user" message.
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "after-id-test",
ModelConfigID: model.ID,
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "first"}},
})
require.NoError(t, err)
// Insert two more messages so we have three total visible
// messages (the initial user message plus these two).
msg2, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: "assistant",
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"second"`), Valid: true},
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
})
require.NoError(t, err)
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: "user",
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"third"`), Valid: true},
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
})
require.NoError(t, err)
// Control: Subscribe with afterMessageID=0 returns ALL messages.
allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
cancelAll()
allMessages := filterMessageEvents(allSnapshot)
require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages")
// Subscribe with afterMessageID set to the second message's ID.
// Only the third message (inserted after msg2) should appear.
partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID)
require.True(t, ok)
cancelPartial()
partialMessages := filterMessageEvents(partialSnapshot)
require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
require.Equal(t, "user", partialMessages[0].Message.Role)
}
// filterMessageEvents returns only the Message-type events from a
// snapshot slice, which is useful for ignoring status / queue events.
func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent {
return slice.Filter(events, func(e codersdk.ChatStreamEvent) bool {
return e.Type == codersdk.ChatStreamEventTypeMessage
})
}
func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
t.Parallel()
+4 -1
View File
@@ -397,7 +397,10 @@ func latestSubagentAssistantMessage(
store database.Store,
chatID uuid.UUID,
) (string, error) {
messages, err := store.GetChatMessagesByChatID(ctx, chatID)
messages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
})
if err != nil {
return "", xerrors.Errorf("get chat messages: %w", err)
}
+18 -2
View File
@@ -368,7 +368,10 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
chat := httpmw.ChatParam(r)
chatID := chat.ID
messages, err := api.Database.GetChatMessagesByChatID(ctx, chatID)
messages, err := api.Database.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get chat messages.",
@@ -681,7 +684,20 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
<-senderClosed
}()
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header)
var afterMessageID int64
if v := r.URL.Query().Get("after_id"); v != "" {
var err error
afterMessageID, err = strconv.ParseInt(v, 10, 64)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid after_id parameter.",
Detail: err.Error(),
})
return
}
}
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
if !ok {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Chat streaming is not available.",
+3 -3
View File
@@ -2465,13 +2465,13 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
return msg, nil
}
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatID(ctx, chatID)
return q.db.GetChatMessagesByChatID(ctx, arg)
}
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
+3 -2
View File
@@ -473,9 +473,10 @@ func (s *MethodTestSuite) TestChats() {
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := database.GetChatMessagesByChatIDParams{ChatID: chat.ID, AfterID: 0}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
+1 -1
View File
@@ -1007,7 +1007,7 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
+4 -4
View File
@@ -1838,18 +1838,18 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
}
// GetChatMessagesByChatID mocks base method.
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, chatID)
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, arg)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessagesByChatID indicates an expected call of GetChatMessagesByChatID.
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, chatID any) *gomock.Call {
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, chatID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
}
// GetChatMessagesForPromptByChatID mocks base method.
+1 -1
View File
@@ -214,7 +214,7 @@ type sqlcQuerier interface {
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
GetChatModelConfigByProviderAndModel(ctx context.Context, arg GetChatModelConfigByProviderAndModelParams) (ChatModelConfig, error)
+8 -2
View File
@@ -3112,13 +3112,19 @@ FROM
chat_messages
WHERE
chat_id = $1::uuid
AND id > $2::bigint
AND visibility IN ('user', 'both')
ORDER BY
created_at ASC
`
func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) {
rows, err := q.db.QueryContext(ctx, getChatMessagesByChatID, chatID)
type GetChatMessagesByChatIDParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
AfterID int64 `db:"after_id" json:"after_id"`
}
func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) {
rows, err := q.db.QueryContext(ctx, getChatMessagesByChatID, arg.ChatID, arg.AfterID)
if err != nil {
return nil, err
}
+1
View File
@@ -41,6 +41,7 @@ FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND id > @after_id::bigint
AND visibility IN ('user', 'both')
ORDER BY
created_at ASC;
+8 -1
View File
@@ -140,9 +140,16 @@ export const watchWorkspace = (
export const watchChat = (
chatId: string,
afterMessageId?: number,
): OneWayWebSocket<TypesGen.ServerSentEvent> => {
const params = new URLSearchParams();
if (afterMessageId !== undefined && afterMessageId > 0) {
params.set("after_id", afterMessageId.toString());
}
const query = params.toString();
const route = `/api/experimental/chats/${chatId}/stream${query ? `?${query}` : ""}`;
return new OneWayWebSocket({
apiRoute: `/api/experimental/chats/${chatId}/stream`,
apiRoute: route,
});
};
@@ -202,7 +202,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -283,7 +283,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -358,7 +358,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -460,7 +460,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
const streamBaseline = streamRenderCount;
@@ -526,7 +526,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -601,7 +601,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -696,7 +696,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
expect(result.current.queuedMessages.map((message) => message.id)).toEqual([
queuedMessage.id,
@@ -781,7 +781,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -852,7 +852,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID1);
expect(watchChat).toHaveBeenCalledWith(chatID1, 1);
});
act(() => {
@@ -888,7 +888,7 @@ describe("useChatStore", () => {
});
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID2);
expect(watchChat).toHaveBeenCalledWith(chatID2, 10);
});
// The old WebSocket was closed during effect cleanup.
@@ -935,7 +935,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
act(() => {
@@ -991,7 +991,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
// Build up stream state so we can observe whether it gets cleared.
@@ -1093,7 +1093,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
});
// Build up stream state first.
@@ -1193,7 +1193,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID1);
expect(watchChat).toHaveBeenCalledWith(chatID1, 1);
});
act(() => {
@@ -1229,7 +1229,7 @@ describe("useChatStore", () => {
});
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID2);
expect(watchChat).toHaveBeenCalledWith(chatID2, 10);
});
expect(result.current.streamState).toBeNull();
@@ -1284,7 +1284,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID1);
expect(watchChat).toHaveBeenCalledWith(chatID1, 1);
});
// Verify queued messages from chat-1 are present.
@@ -1310,7 +1310,7 @@ describe("useChatStore", () => {
// After the switch, queued messages from chat-1 should NOT be
// visible — the store resets them on chatID change.
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID2);
expect(watchChat).toHaveBeenCalledWith(chatID2, undefined);
});
expect(result.current.queuedMessages).toEqual([]);
});
@@ -1352,7 +1352,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
// Emit a batch with message_parts followed by a status change
@@ -1424,7 +1424,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
act(() => {
@@ -1483,7 +1483,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
act(() => {
@@ -1536,7 +1536,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
act(() => {
@@ -1598,7 +1598,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
// Set retry state first.
@@ -1676,7 +1676,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
act(() => {
@@ -1734,7 +1734,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
act(() => {
@@ -1783,7 +1783,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
// Set an error via an error stream event first.
@@ -1847,7 +1847,7 @@ describe("useChatStore", () => {
);
await waitFor(() => {
expect(watchChat).toHaveBeenCalledWith(chatID);
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
});
// Transition to running — should call clearChatErrorReason.
@@ -435,6 +435,17 @@ export const useChatStore = (
const store = storeRef.current;
// Compute the last REST-fetched message ID so the stream can
// skip messages the client already has. We use a ref so the
// socket effect can read the latest value without including
// chatMessages in its dependency array (which would cause
// unnecessary reconnections).
const lastMessageIdRef = useRef<number | undefined>(undefined);
lastMessageIdRef.current =
chatMessages && chatMessages.length > 0
? chatMessages[chatMessages.length - 1].id
: undefined;
const updateSidebarChat = useCallback(
(updater: (chat: TypesGen.Chat) => TypesGen.Chat) => {
if (!chatID) {
@@ -550,7 +561,9 @@ export const useChatStore = (
return;
}
const socket = watchChat(chatID);
// Pass the last REST-fetched message ID so the stream
// only sends newer messages.
const socket = watchChat(chatID, lastMessageIdRef.current);
const handleMessage = (
payload: OneWayMessageEvent<TypesGen.ServerSentEvent>,
) => {