diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 1d51b4299e..9ed76a5d30 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1027,6 +1027,12 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() check.Args(msg.ID).Asserts(chat, policy.ActionRead).Returns(msg) })) + s.Run("GetChatGoalMessageIDsByMessageIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + messageIDs := []int64{1, 2} + goalMessageIDs := []int64{2} + dbm.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), messageIDs).Return(goalMessageIDs, nil).AnyTimes() + check.Args(messageIDs).Asserts().Returns(goalMessageIDs) + })) s.Run("GetChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})} diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index af6ac6699c..88b171ead9 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -2159,6 +2159,7 @@ func TestSubscribeDedupesLocallyDeliveredMessageOnNotifyCatchup(t *testing.T) { ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{initialMessage.ID}).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), // DB catchup runs unconditionally on every notify; the delivered // set dedupes against locally-delivered messages. @@ -2208,6 +2209,7 @@ func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{initialMessage.ID}).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), // DB catchup runs unconditionally; cached id=2 is deduped via // the delivered set so this query returning nil is sufficient. @@ -2265,11 +2267,13 @@ func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) { ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{initialMessage.ID}).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 1, }).Return([]database.ChatMessage{catchupMessage}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{catchupMessage.ID}).Return(nil, nil), ) server := newSubscribeTestServer(t, db) @@ -2314,11 +2318,13 @@ func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) { ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{initialMessage.ID}).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{editedMessage}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{editedMessage.ID}).Return(nil, nil), ) server := newSubscribeTestServer(t, db) diff --git a/coderd/x/chatd/subscribe_out_of_order_internal_test.go b/coderd/x/chatd/subscribe_out_of_order_internal_test.go index a6bb083785..ed8bbdbc49 100644 --- a/coderd/x/chatd/subscribe_out_of_order_internal_test.go +++ b/coderd/x/chatd/subscribe_out_of_order_internal_test.go @@ -38,6 +38,7 @@ func TestSubscribeDeliversOutOfOrderDurableMessage(t *testing.T) { ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialUser, initialAssistant}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{initialUser.ID, initialAssistant.ID}).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), ) // Notify-driven catch-up queries return nothing so the test only @@ -177,6 +178,7 @@ func TestSubscribeRunsDBFallbackWhenCacheDeliversUnrelatedMessage(t *testing.T) ChatID: chatID, AfterID: 5, }).Return([]database.ChatMessage{crossReplica}, nil), + db.EXPECT().GetChatGoalMessageIDsByMessageIDs(gomock.Any(), []int64{crossReplica.ID}).Return(nil, nil), ) server := newSubscribeTestServer(t, db)