diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index d4dc4451ad..4f7f9eb7f4 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -3261,7 +3261,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { // Subscribe before accepting the WebSocket so that failures // can still be reported as normal HTTP errors. - snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID) + snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID) // Subscribe only fails today when the receiver is nil, which // the chatDaemon == nil guard above already catches. This is // defensive against future Subscribe failure modes. diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 822909b1bf..bd502f9fcf 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -107,6 +107,9 @@ const ( // cross-replica relay subscribers time to connect and // snapshot the buffer before it is garbage-collected. bufferRetainGracePeriod = 5 * time.Second + // chatStreamControlFetchTimeout bounds subscriber-owned + // control-path DB reads when the caller has no deadline. + chatStreamControlFetchTimeout = 5 * time.Second // streamJanitorInterval is how often sweepIdleStreams runs. // Worst-case retention is bufferRetainGracePeriod + @@ -4244,6 +4247,31 @@ func (p *Server) heartbeatTick(ctx context.Context) { } } +// streamSubscriberControlFetchContext keeps a control-path lookup tied to the +// requesting subscriber while applying a fallback timeout when the caller has +// no deadline. +func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) { + if _, ok := ctx.Deadline(); ok { + return ctx, func() {} + } + return context.WithTimeout(ctx, chatStreamControlFetchTimeout) +} + +func subscribeWithInitialError(chatID uuid.UUID, message string) ( + []codersdk.ChatStreamEvent, + <-chan codersdk.ChatStreamEvent, + func(), + bool, +) { + events := make(chan codersdk.ChatStreamEvent) + close(events) + return []codersdk.ChatStreamEvent{{ + Type: codersdk.ChatStreamEventTypeError, + ChatID: chatID, + Error: &codersdk.ChatError{Message: message}, + }}, events, func() {}, true +} + func (p *Server) Subscribe( ctx context.Context, chatID uuid.UUID, @@ -4258,9 +4286,40 @@ func (p *Server) Subscribe( if p == nil { return nil, nil, nil, false } - if ctx == nil { - ctx = context.Background() + + chat, err := p.db.GetChatByID(ctx, chatID) + if err != nil { + if dbauthz.IsNotAuthorizedError(err) { + return nil, nil, nil, false + } + p.logger.Warn(ctx, "failed to load chat for stream subscription", + slog.F("chat_id", chatID), + slog.Error(err), + ) + return subscribeWithInitialError(chatID, "failed to load initial snapshot") } + return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID) +} + +// SubscribeAuthorized subscribes an already-authorized chat to merged stream +// updates. The passed chat row proves authorization, but SubscribeAuthorized +// still reloads the chat after the stream subscriptions are armed so the +// initial status and relay setup use fresh state. +func (p *Server) SubscribeAuthorized( + ctx context.Context, + chat database.Chat, + requestHeader http.Header, + afterMessageID int64, +) ( + []codersdk.ChatStreamEvent, + <-chan codersdk.ChatStreamEvent, + func(), + bool, +) { + if p == nil { + return nil, nil, nil, false + } + chatID := chat.ID // Subscribe to the local stream for message_parts and same-replica // persisted messages. Capture the current retry phase under the same @@ -4326,6 +4385,34 @@ func (p *Server) Subscribe( } } + cancel := func() { + mergedCancel() + for _, cancelFn := range allCancels { + if cancelFn != nil { + cancelFn() + } + } + } + + // Re-read the chat after the local/pubsub subscriptions are active so + // the initial status event and any enterprise relay setup use fresh + // state instead of the middleware-loaded row. + refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx) + snapshotChat, err := func() (database.Chat, error) { + defer refreshCancel() + //nolint:gocritic // SubscribeAuthorized already validated the + // caller; this refresh only loads the latest status/worker for + // the already-authorized stream subscription. + return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID) + }() + if err != nil { + p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state", + slog.F("chat_id", chatID), + slog.Error(err), + ) + snapshotChat = chat + } + // Build initial snapshot synchronously. The pubsub subscription // is already active so no notifications can be lost during this // window. @@ -4377,8 +4464,12 @@ func (p *Server) Subscribe( } } - // Load initial queue. - queued, err := p.db.GetChatQueuedMessages(ctx, chatID) + // Load initial queue. Queue snapshots are intentionally not + // singleflighted because a chat-scoped key cannot distinguish the + // pre- and post-notification queue state. + queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx) + queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID) + queueCancel() if err != nil { p.logger.Error(ctx, "failed to load initial queued messages", slog.Error(err), @@ -4397,44 +4488,24 @@ func (p *Server) Subscribe( }) } - // Get initial chat state to determine if we need a relay. - chat, chatErr := p.db.GetChatByID(ctx, chatID) - // Include the current chat status in the snapshot so the // frontend can gate message_part processing correctly from // the very first batch, without waiting for a separate REST // query. - if chatErr != nil { - p.logger.Error(ctx, "failed to load initial chat state", - slog.Error(chatErr), - slog.F("chat_id", chatID), - ) - initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatError{Message: "failed to load initial snapshot"}, - }) - } else { - statusEvent := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeStatus, - ChatID: chatID, - Status: &codersdk.ChatStreamStatus{ - Status: codersdk.ChatStatus(chat.Status), - }, - } - // Prepend so the frontend sees the current stream phases - // before any message_part events. - prefix := []codersdk.ChatStreamEvent{statusEvent} - if retryEvent != nil { - prefix = append(prefix, *retryEvent) - retryEvent = nil - } - initialSnapshot = append(prefix, initialSnapshot...) + statusEvent := codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeStatus, + ChatID: chatID, + Status: &codersdk.ChatStreamStatus{ + Status: codersdk.ChatStatus(snapshotChat.Status), + }, } - + // Prepend so the frontend sees the current stream phases + // before any message_part events. + prefix := []codersdk.ChatStreamEvent{statusEvent} if retryEvent != nil { - initialSnapshot = append(initialSnapshot, *retryEvent) + prefix = append(prefix, *retryEvent) } + initialSnapshot = append(prefix, initialSnapshot...) // Track the highest durable message ID delivered to this subscriber, // whether it came from the initial DB snapshot, the same-replica local @@ -4444,18 +4515,17 @@ func (p *Server) Subscribe( lastMessageID = messages[len(messages)-1].ID } - // When an enterprise SubscribeFn is provided and the chat - // lookup succeeded, call it to get relay events (message_parts - // from remote replicas). OSS now owns pubsub subscription, - // message catch-up, queue updates, and status forwarding; - // enterprise only manages relay dialing. + // When an enterprise SubscribeFn is provided, call it to get relay events + // (message_parts from remote replicas). OSS owns pubsub subscription, + // message catch-up, queue updates, and status forwarding; enterprise only + // manages relay dialing. var relayEvents <-chan codersdk.ChatStreamEvent var statusNotifications chan StatusNotification - if p.subscribeFn != nil && chatErr == nil { + if p.subscribeFn != nil { statusNotifications = make(chan StatusNotification, 10) relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{ ChatID: chatID, - Chat: chat, + Chat: snapshotChat, WorkerID: p.workerID, StatusNotifications: statusNotifications, RequestHeader: requestHeader, @@ -4600,7 +4670,9 @@ func (p *Server) Subscribe( } } if notify.QueueUpdate { - queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID) + queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx) + queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID) + queueCancel() if queueErr != nil { p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification", slog.F("chat_id", chatID), @@ -4676,14 +4748,6 @@ func (p *Server) Subscribe( } }() - cancel := func() { - mergedCancel() - for _, cancelFn := range allCancels { - if cancelFn != nil { - cancelFn() - } - } - } return initialSnapshot, mergedEvents, cancel, true } diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index d58828afd6..bd1d33f788 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -1983,14 +1983,14 @@ func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) { ChatID: chatID, Role: database.ChatMessageRoleAssistant, } - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newSubscribeTestServer(t, db) @@ -2026,14 +2026,14 @@ func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) ChatID: chatID, Role: codersdk.ChatMessageRoleAssistant, } - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newSubscribeTestServer(t, db) @@ -2077,14 +2077,14 @@ func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) { ChatID: chatID, Role: database.ChatMessageRoleAssistant, } - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 1, @@ -2126,14 +2126,14 @@ func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) { ChatID: chatID, Role: database.ChatMessageRoleUser, } - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return([]database.ChatMessage{initialMessage}, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, @@ -2163,14 +2163,14 @@ func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) { chatID := uuid.New() chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newSubscribeTestServer(t, db) @@ -2200,12 +2200,13 @@ func TestSubscribeReplaysCurrentRetryPhaseInSnapshot(t *testing.T) { chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newBufferedSubscribeTestServer(t, db, chatID) @@ -2241,6 +2242,8 @@ func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) { server := newSubscribeTestServer(t, db) gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, @@ -2249,7 +2252,6 @@ func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) { return nil, nil }), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) @@ -2275,12 +2277,13 @@ func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) { chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newBufferedSubscribeTestServer(t, db, chatID) @@ -2310,12 +2313,13 @@ func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) { chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newBufferedSubscribeTestServer(t, db, chatID) @@ -2350,12 +2354,13 @@ func TestSubscribeDoesNotReplayRetryAfterTerminalStatus(t *testing.T) { chat := database.Chat{ID: chatID, Status: database.ChatStatusCompleted} gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newBufferedSubscribeTestServer(t, db, chatID) @@ -2382,14 +2387,14 @@ func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) { chatID := uuid.New() chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newSubscribeTestServer(t, db) @@ -2422,14 +2427,14 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) { chatID := uuid.New() chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ ChatID: chatID, AfterID: 0, }).Return(nil, nil), db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), ) server := newSubscribeTestServer(t, db) @@ -2461,6 +2466,101 @@ func newTestRetryPayload() *codersdk.ChatStreamRetry { return payload } +func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + server := newSubscribeTestServer(t, db) + + chatID := uuid.New() + staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending} + + state := server.getOrCreateStreamState(chatID) + state.mu.Lock() + state.buffer = []codersdk.ChatStreamEvent{{ + Type: codersdk.ChatStreamEventTypeMessagePart, + ChatID: chatID, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessageText("thinking"), + }, + }} + state.mu.Unlock() + + gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")), + db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }).Return(nil, nil), + db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), + ) + + initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0) + require.True(t, ok) + defer cancel() + + require.Len(t, initialSnapshot, 2) + require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type) + require.NotNil(t, initialSnapshot[0].Status) + require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status) + require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type) + require.NotNil(t, initialSnapshot[1].MessagePart) + require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text) + requireNoStreamEvent(t, events, 200*time.Millisecond) +} + +func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + server := newSubscribeTestServer(t, db) + + chatID := uuid.New() + db.EXPECT().GetChatByID(gomock.Any(), chatID). + Return(database.Chat{}, dbauthz.NotAuthorizedError{Err: xerrors.New("not authorized")}) + + snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) + require.False(t, ok) + require.Nil(t, snapshot) + require.Nil(t, events) + require.Nil(t, cancel) + + _, exists := server.chatStreams.Load(chatID) + require.False(t, exists) +} + +func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + server := newSubscribeTestServer(t, db) + + chatID := uuid.New() + db.EXPECT().GetChatByID(gomock.Any(), chatID). + Return(database.Chat{}, xerrors.New("transient lookup failure")) + + snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) + require.True(t, ok) + require.NotNil(t, cancel) + require.Len(t, snapshot, 1) + require.Equal(t, codersdk.ChatStreamEventTypeError, snapshot[0].Type) + require.Equal(t, chatID, snapshot[0].ChatID) + require.Equal(t, "failed to load initial snapshot", snapshot[0].Error.Message) + + _, open := <-events + require.False(t, open) + + _, exists := server.chatStreams.Load(chatID) + require.False(t, exists) +} + func newSubscribeTestServer(t *testing.T, db database.Store) *Server { t.Helper() diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index 2cfa3da157..ad8d0867a1 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -433,8 +433,13 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) { user, org, model := seedChatDependencies(t, db) chat := seedRemoteRunningChat(ctx, t, db, org.ID, user, model, workerID, "relay-snapshot") + staleChat := chat + staleChat.Status = database.ChatStatusWaiting + staleChat.WorkerID = uuid.NullUUID{} + staleChat.StartedAt = sql.NullTime{} + staleChat.HeartbeatAt = sql.NullTime{} - initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + initialSnapshot, events, cancel, ok := subscriber.SubscribeAuthorized(ctx, staleChat, nil, 0) require.True(t, ok) t.Cleanup(cancel) @@ -458,15 +463,15 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) { require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts) - // The initial snapshot should still contain the status event - // from the OSS preamble. - var hasStatus bool + // The initial snapshot should contain the refreshed running status, + // not the stale waiting status passed into SubscribeAuthorized. + var snapshotStatus codersdk.ChatStatus for _, event := range initialSnapshot { - if event.Type == codersdk.ChatStreamEventTypeStatus { - hasStatus = true + if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil { + snapshotStatus = event.Status.Status } } - require.True(t, hasStatus, "initial snapshot should contain status event") + require.Equal(t, codersdk.ChatStatusRunning, snapshotStatus) } func TestSubscribeRetryEventAcrossInstances(t *testing.T) {