mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: use cursor-based query for chat stream notifications (#22510)
## Problem
The pubsub notification handler in `chatd` re-fetched **all** messages
from the DB on every new message notification, then filtered in Go with
`msg.ID > lastMessageID`. This grows linearly with conversation length —
every new message triggers a full table scan of that chat's history.
The `AfterMessageID` field in the pubsub notification payload was
clearly designed for cursor-based fetching, but no matching query
existed.
## Fix
- Add `GetChatMessagesByChatIDAfter` SQL query with `WHERE id >
@after_id`, so the database does the filtering instead of Go.
- Use it in the pubsub notification handler in `chatd.go`, passing
`lastMessageID` as the cursor.
- Implement the dbauthz wrapper (was a `panic("not implemented")` stub
from codegen) with the same read-check-on-parent-chat pattern as
adjacent methods.
- Add dbauthz test coverage for the new method.
**Not changed:** The initial snapshot in `Subscribe()` still loads all
messages — that's correct, since a newly-connecting client needs the
full conversation state. The waste was only in the ongoing notification
path.
This commit is contained in:
+24
-16
@@ -988,6 +988,7 @@ func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
afterMessageID int64,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
@@ -1013,8 +1014,14 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
// Load initial messages from DB
|
||||
messages, err := p.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
// Load initial messages from DB. When afterMessageID > 0 the
|
||||
// caller already has messages up to that ID (e.g. from the REST
|
||||
// endpoint), so we only fetch newer ones to avoid sending
|
||||
// duplicate data.
|
||||
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: afterMessageID,
|
||||
})
|
||||
if err == nil {
|
||||
for _, msg := range messages {
|
||||
sdkMsg := db2sdk.ChatMessage(msg)
|
||||
@@ -1191,23 +1198,24 @@ func (p *Server) Subscribe(
|
||||
case notify := <-notifications:
|
||||
// Handle different notification types
|
||||
if notify.AfterMessageID > 0 {
|
||||
// Read new messages from DB
|
||||
messages, err := p.db.GetChatMessagesByChatID(mergedCtx, chatID)
|
||||
// Read only new messages from DB.
|
||||
messages, err := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: lastMessageID,
|
||||
})
|
||||
if err == nil {
|
||||
for _, msg := range messages {
|
||||
if msg.ID > lastMessageID {
|
||||
sdkMsg := db2sdk.ChatMessage(msg)
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessage,
|
||||
ChatID: chatID,
|
||||
Message: &sdkMsg,
|
||||
}:
|
||||
}
|
||||
lastMessageID = msg.ID
|
||||
sdkMsg := db2sdk.ChatMessage(msg)
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessage,
|
||||
ChatID: chatID,
|
||||
Message: &sdkMsg,
|
||||
}:
|
||||
}
|
||||
lastMessageID = msg.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+101
-7
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -60,7 +61,7 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil)
|
||||
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
@@ -202,7 +203,10 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 1)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
@@ -252,7 +256,10 @@ func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 0)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
||||
@@ -275,7 +282,10 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
initialMessages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
initialMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, initialMessages, 1)
|
||||
editedMessageID := initialMessages[0].ID
|
||||
@@ -322,7 +332,10 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
require.Len(t, editedSDK.Content, 1)
|
||||
require.Equal(t, "edited", editedSDK.Content[0].Text)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, chat.ID)
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.Equal(t, editedMessageID, messages[0].ID)
|
||||
@@ -657,7 +670,7 @@ func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil)
|
||||
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
@@ -686,7 +699,7 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil)
|
||||
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
@@ -708,6 +721,87 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a chat — this inserts one initial "user" message.
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "after-id-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "first"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert two more messages so we have three total visible
|
||||
// messages (the initial user message plus these two).
|
||||
msg2, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"second"`), Valid: true},
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: "user",
|
||||
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"third"`), Valid: true},
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Control: Subscribe with afterMessageID=0 returns ALL messages.
|
||||
allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
cancelAll()
|
||||
|
||||
allMessages := filterMessageEvents(allSnapshot)
|
||||
require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages")
|
||||
|
||||
// Subscribe with afterMessageID set to the second message's ID.
|
||||
// Only the third message (inserted after msg2) should appear.
|
||||
partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID)
|
||||
require.True(t, ok)
|
||||
cancelPartial()
|
||||
|
||||
partialMessages := filterMessageEvents(partialSnapshot)
|
||||
require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
|
||||
require.Equal(t, "user", partialMessages[0].Message.Role)
|
||||
}
|
||||
|
||||
// filterMessageEvents returns only the Message-type events from a
|
||||
// snapshot slice, which is useful for ignoring status / queue events.
|
||||
func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent {
|
||||
return slice.Filter(events, func(e codersdk.ChatStreamEvent) bool {
|
||||
return e.Type == codersdk.ChatStreamEventTypeMessage
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -397,7 +397,10 @@ func latestSubagentAssistantMessage(
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
) (string, error) {
|
||||
messages, err := store.GetChatMessagesByChatID(ctx, chatID)
|
||||
messages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
|
||||
+18
-2
@@ -368,7 +368,10 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
|
||||
messages, err := api.Database.GetChatMessagesByChatID(ctx, chatID)
|
||||
messages, err := api.Database.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat messages.",
|
||||
@@ -681,7 +684,20 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header)
|
||||
var afterMessageID int64
|
||||
if v := r.URL.Query().Get("after_id"); v != "" {
|
||||
var err error
|
||||
afterMessageID, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid after_id parameter.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
|
||||
@@ -2465,13 +2465,13 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatID(ctx, chatID)
|
||||
return q.db.GetChatMessagesByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
|
||||
@@ -473,9 +473,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := database.GetChatMessagesByChatIDParams{ChatID: chat.ID, AfterID: 0}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), chat.ID).Return(msgs, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
|
||||
@@ -1007,7 +1007,7 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatID").Observe(time.Since(start).Seconds())
|
||||
|
||||
@@ -1838,18 +1838,18 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, chatID)
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID indicates an expected call of GetChatMessagesByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, chatID any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, chatID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID mocks base method.
|
||||
|
||||
@@ -214,7 +214,7 @@ type sqlcQuerier interface {
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigByProviderAndModel(ctx context.Context, arg GetChatModelConfigByProviderAndModelParams) (ChatModelConfig, error)
|
||||
|
||||
@@ -3112,13 +3112,19 @@ FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = $1::uuid
|
||||
AND id > $2::bigint
|
||||
AND visibility IN ('user', 'both')
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatMessagesByChatID, chatID)
|
||||
type GetChatMessagesByChatIDParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
AfterID int64 `db:"after_id" json:"after_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatMessagesByChatID, arg.ChatID, arg.AfterID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND id > @after_id::bigint
|
||||
AND visibility IN ('user', 'both')
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
+8
-1
@@ -140,9 +140,16 @@ export const watchWorkspace = (
|
||||
|
||||
export const watchChat = (
|
||||
chatId: string,
|
||||
afterMessageId?: number,
|
||||
): OneWayWebSocket<TypesGen.ServerSentEvent> => {
|
||||
const params = new URLSearchParams();
|
||||
if (afterMessageId !== undefined && afterMessageId > 0) {
|
||||
params.set("after_id", afterMessageId.toString());
|
||||
}
|
||||
const query = params.toString();
|
||||
const route = `/api/experimental/chats/${chatId}/stream${query ? `?${query}` : ""}`;
|
||||
return new OneWayWebSocket({
|
||||
apiRoute: `/api/experimental/chats/${chatId}/stream`,
|
||||
apiRoute: route,
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -202,7 +202,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -283,7 +283,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -358,7 +358,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -460,7 +460,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
const streamBaseline = streamRenderCount;
|
||||
@@ -526,7 +526,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -601,7 +601,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -696,7 +696,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
expect(result.current.queuedMessages.map((message) => message.id)).toEqual([
|
||||
queuedMessage.id,
|
||||
@@ -781,7 +781,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -852,7 +852,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID1);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID1, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -888,7 +888,7 @@ describe("useChatStore", () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID2);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID2, 10);
|
||||
});
|
||||
|
||||
// The old WebSocket was closed during effect cleanup.
|
||||
@@ -935,7 +935,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -991,7 +991,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
// Build up stream state so we can observe whether it gets cleared.
|
||||
@@ -1093,7 +1093,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
// Build up stream state first.
|
||||
@@ -1193,7 +1193,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID1);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID1, 1);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -1229,7 +1229,7 @@ describe("useChatStore", () => {
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID2);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID2, 10);
|
||||
});
|
||||
|
||||
expect(result.current.streamState).toBeNull();
|
||||
@@ -1284,7 +1284,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID1);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID1, 1);
|
||||
});
|
||||
|
||||
// Verify queued messages from chat-1 are present.
|
||||
@@ -1310,7 +1310,7 @@ describe("useChatStore", () => {
|
||||
// After the switch, queued messages from chat-1 should NOT be
|
||||
// visible — the store resets them on chatID change.
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID2);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID2, undefined);
|
||||
});
|
||||
expect(result.current.queuedMessages).toEqual([]);
|
||||
});
|
||||
@@ -1352,7 +1352,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
// Emit a batch with message_parts followed by a status change
|
||||
@@ -1424,7 +1424,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -1483,7 +1483,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -1536,7 +1536,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -1598,7 +1598,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
// Set retry state first.
|
||||
@@ -1676,7 +1676,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -1734,7 +1734,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
@@ -1783,7 +1783,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
// Set an error via an error stream event first.
|
||||
@@ -1847,7 +1847,7 @@ describe("useChatStore", () => {
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID);
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
// Transition to running — should call clearChatErrorReason.
|
||||
|
||||
@@ -435,6 +435,17 @@ export const useChatStore = (
|
||||
|
||||
const store = storeRef.current;
|
||||
|
||||
// Compute the last REST-fetched message ID so the stream can
|
||||
// skip messages the client already has. We use a ref so the
|
||||
// socket effect can read the latest value without including
|
||||
// chatMessages in its dependency array (which would cause
|
||||
// unnecessary reconnections).
|
||||
const lastMessageIdRef = useRef<number | undefined>(undefined);
|
||||
lastMessageIdRef.current =
|
||||
chatMessages && chatMessages.length > 0
|
||||
? chatMessages[chatMessages.length - 1].id
|
||||
: undefined;
|
||||
|
||||
const updateSidebarChat = useCallback(
|
||||
(updater: (chat: TypesGen.Chat) => TypesGen.Chat) => {
|
||||
if (!chatID) {
|
||||
@@ -550,7 +561,9 @@ export const useChatStore = (
|
||||
return;
|
||||
}
|
||||
|
||||
const socket = watchChat(chatID);
|
||||
// Pass the last REST-fetched message ID so the stream
|
||||
// only sends newer messages.
|
||||
const socket = watchChat(chatID, lastMessageIdRef.current);
|
||||
const handleMessage = (
|
||||
payload: OneWayMessageEvent<TypesGen.ServerSentEvent>,
|
||||
) => {
|
||||
|
||||
Reference in New Issue
Block a user