mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd/x/chatd): refresh chat status and bound subscriber reads on Subscribe (#24095)
Tightens the chat stream subscription path on a few related axes. None of these changes touch the steady-state event flow; they all concern the subscribe handshake. ## Motivation `Server.Subscribe` carries three responsibilities that were entangled: 1. Authorize the caller against the chat row. 2. Arm local + pubsub subscriptions before any DB reads (subscribe-first-then-query). 3. Build the initial snapshot from a fresh chat row, message history, and queue. When all three live in one function and share the request context, a few unfortunate behaviors fall out: - The HTTP handler's middleware already loaded and authorized the chat row, but `Subscribe(chatID)` discarded it and re-fetched on every WebSocket connection. - The chat row used to populate the initial `status` event was loaded *before* the pubsub subscription was armed, so a status transition that happened in that window was silently lost. - Control-path DB reads inherited whatever context the caller passed in. A caller without a deadline could wedge a subscriber goroutine indefinitely on a stalled DB. - A transient failure of the chat re-read collapsed the entire subscription instead of degrading gracefully. ## What changes **Split the auth boundary out into the type signature.** A new `SubscribeAuthorized(ctx, chat, ...)` takes the already-authorized row directly. The HTTP handler in `coderd/exp_chats.go` calls it with the chat row from `httpmw.ChatParam`, eliminating the redundant `GetChatByID`. `Subscribe(chatID)` is preserved as a thin wrapper for callers that don't have a chat row in hand (tests, internal callers); it does the auth lookup and delegates. **Re-read the chat after arming subscriptions.** Inside `SubscribeAuthorized`, after the local stream and pubsub subscriptions are active, we reload the chat row to populate the initial `status` event and any enterprise relay setup. Combined with the existing subscribe-first-then-query pattern, this closes the gap where a status transition between the middleware's load and the subscription arming would not appear in either the initial snapshot or a live notification. **Fall back to the middleware row on refresh failure.** If the post-subscription refresh fails (transient DB blip, brief pool exhaustion), we log a warning and reuse the row that proved authorization in the first place. Messages, queue, and pubsub are all independent of this row, so the stream still works; the initial `status` is just slightly stale and self-corrects via the next pubsub event. **Bound subscriber control-path DB reads.** A new `streamSubscriberControlFetchContext` helper applies a 5-second fallback timeout only when the caller has no deadline of their own. Used at the chat refresh, the initial queue load, and the queue-update goroutine following pubsub notifications. HTTP-driven callers pass through unchanged; background callers can no longer hang forever on a stalled DB and leak subscriber goroutines, pubsub subscriptions, and `chatStreams` entries.
This commit is contained in:
+1
-1
@@ -3261,7 +3261,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
|
||||
+115
-51
@@ -107,6 +107,9 @@ const (
|
||||
// cross-replica relay subscribers time to connect and
|
||||
// snapshot the buffer before it is garbage-collected.
|
||||
bufferRetainGracePeriod = 5 * time.Second
|
||||
// chatStreamControlFetchTimeout bounds subscriber-owned
|
||||
// control-path DB reads when the caller has no deadline.
|
||||
chatStreamControlFetchTimeout = 5 * time.Second
|
||||
|
||||
// streamJanitorInterval is how often sweepIdleStreams runs.
|
||||
// Worst-case retention is bufferRetainGracePeriod +
|
||||
@@ -4244,6 +4247,31 @@ func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// streamSubscriberControlFetchContext keeps a control-path lookup tied to the
|
||||
// requesting subscriber while applying a fallback timeout when the caller has
|
||||
// no deadline.
|
||||
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if _, ok := ctx.Deadline(); ok {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
|
||||
}
|
||||
|
||||
func subscribeWithInitialError(chatID uuid.UUID, message string) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
bool,
|
||||
) {
|
||||
events := make(chan codersdk.ChatStreamEvent)
|
||||
close(events)
|
||||
return []codersdk.ChatStreamEvent{{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatError{Message: message},
|
||||
}}, events, func() {}, true
|
||||
}
|
||||
|
||||
func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
@@ -4258,9 +4286,40 @@ func (p *Server) Subscribe(
|
||||
if p == nil {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
|
||||
chat, err := p.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
p.logger.Warn(ctx, "failed to load chat for stream subscription",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return subscribeWithInitialError(chatID, "failed to load initial snapshot")
|
||||
}
|
||||
return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
|
||||
}
|
||||
|
||||
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
|
||||
// updates. The passed chat row proves authorization, but SubscribeAuthorized
|
||||
// still reloads the chat after the stream subscriptions are armed so the
|
||||
// initial status and relay setup use fresh state.
|
||||
func (p *Server) SubscribeAuthorized(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
requestHeader http.Header,
|
||||
afterMessageID int64,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
bool,
|
||||
) {
|
||||
if p == nil {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
chatID := chat.ID
|
||||
|
||||
// Subscribe to the local stream for message_parts and same-replica
|
||||
// persisted messages. Capture the current retry phase under the same
|
||||
@@ -4326,6 +4385,34 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-read the chat after the local/pubsub subscriptions are active so
|
||||
// the initial status event and any enterprise relay setup use fresh
|
||||
// state instead of the middleware-loaded row.
|
||||
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
|
||||
snapshotChat, err := func() (database.Chat, error) {
|
||||
defer refreshCancel()
|
||||
//nolint:gocritic // SubscribeAuthorized already validated the
|
||||
// caller; this refresh only loads the latest status/worker for
|
||||
// the already-authorized stream subscription.
|
||||
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
|
||||
}()
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
snapshotChat = chat
|
||||
}
|
||||
|
||||
// Build initial snapshot synchronously. The pubsub subscription
|
||||
// is already active so no notifications can be lost during this
|
||||
// window.
|
||||
@@ -4377,8 +4464,12 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
// Load initial queue.
|
||||
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
|
||||
// Load initial queue. Queue snapshots are intentionally not
|
||||
// singleflighted because a chat-scoped key cannot distinguish the
|
||||
// pre- and post-notification queue state.
|
||||
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
|
||||
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
||||
queueCancel()
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to load initial queued messages",
|
||||
slog.Error(err),
|
||||
@@ -4397,44 +4488,24 @@ func (p *Server) Subscribe(
|
||||
})
|
||||
}
|
||||
|
||||
// Get initial chat state to determine if we need a relay.
|
||||
chat, chatErr := p.db.GetChatByID(ctx, chatID)
|
||||
|
||||
// Include the current chat status in the snapshot so the
|
||||
// frontend can gate message_part processing correctly from
|
||||
// the very first batch, without waiting for a separate REST
|
||||
// query.
|
||||
if chatErr != nil {
|
||||
p.logger.Error(ctx, "failed to load initial chat state",
|
||||
slog.Error(chatErr),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatError{Message: "failed to load initial snapshot"},
|
||||
})
|
||||
} else {
|
||||
statusEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(chat.Status),
|
||||
},
|
||||
}
|
||||
// Prepend so the frontend sees the current stream phases
|
||||
// before any message_part events.
|
||||
prefix := []codersdk.ChatStreamEvent{statusEvent}
|
||||
if retryEvent != nil {
|
||||
prefix = append(prefix, *retryEvent)
|
||||
retryEvent = nil
|
||||
}
|
||||
initialSnapshot = append(prefix, initialSnapshot...)
|
||||
statusEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(snapshotChat.Status),
|
||||
},
|
||||
}
|
||||
|
||||
// Prepend so the frontend sees the current stream phases
|
||||
// before any message_part events.
|
||||
prefix := []codersdk.ChatStreamEvent{statusEvent}
|
||||
if retryEvent != nil {
|
||||
initialSnapshot = append(initialSnapshot, *retryEvent)
|
||||
prefix = append(prefix, *retryEvent)
|
||||
}
|
||||
initialSnapshot = append(prefix, initialSnapshot...)
|
||||
|
||||
// Track the highest durable message ID delivered to this subscriber,
|
||||
// whether it came from the initial DB snapshot, the same-replica local
|
||||
@@ -4444,18 +4515,17 @@ func (p *Server) Subscribe(
|
||||
lastMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
// When an enterprise SubscribeFn is provided and the chat
|
||||
// lookup succeeded, call it to get relay events (message_parts
|
||||
// from remote replicas). OSS now owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding;
|
||||
// enterprise only manages relay dialing.
|
||||
// When an enterprise SubscribeFn is provided, call it to get relay events
|
||||
// (message_parts from remote replicas). OSS owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding; enterprise only
|
||||
// manages relay dialing.
|
||||
var relayEvents <-chan codersdk.ChatStreamEvent
|
||||
var statusNotifications chan StatusNotification
|
||||
if p.subscribeFn != nil && chatErr == nil {
|
||||
if p.subscribeFn != nil {
|
||||
statusNotifications = make(chan StatusNotification, 10)
|
||||
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
|
||||
ChatID: chatID,
|
||||
Chat: chat,
|
||||
Chat: snapshotChat,
|
||||
WorkerID: p.workerID,
|
||||
StatusNotifications: statusNotifications,
|
||||
RequestHeader: requestHeader,
|
||||
@@ -4600,7 +4670,9 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
if notify.QueueUpdate {
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID)
|
||||
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
||||
queueCancel()
|
||||
if queueErr != nil {
|
||||
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
|
||||
slog.F("chat_id", chatID),
|
||||
@@ -4676,14 +4748,6 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}()
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
return initialSnapshot, mergedEvents, cancel, true
|
||||
}
|
||||
|
||||
|
||||
@@ -1983,14 +1983,14 @@ func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -2026,14 +2026,14 @@ func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T)
|
||||
ChatID: chatID,
|
||||
Role: codersdk.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -2077,14 +2077,14 @@ func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 1,
|
||||
@@ -2126,14 +2126,14 @@ func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
@@ -2163,14 +2163,14 @@ func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -2200,12 +2200,13 @@ func TestSubscribeReplaysCurrentRetryPhaseInSnapshot(t *testing.T) {
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newBufferedSubscribeTestServer(t, db, chatID)
|
||||
@@ -2241,6 +2242,8 @@ func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) {
|
||||
server := newSubscribeTestServer(t, db)
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
@@ -2249,7 +2252,6 @@ func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) {
|
||||
return nil, nil
|
||||
}),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
@@ -2275,12 +2277,13 @@ func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) {
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newBufferedSubscribeTestServer(t, db, chatID)
|
||||
@@ -2310,12 +2313,13 @@ func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) {
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newBufferedSubscribeTestServer(t, db, chatID)
|
||||
@@ -2350,12 +2354,13 @@ func TestSubscribeDoesNotReplayRetryAfterTerminalStatus(t *testing.T) {
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusCompleted}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newBufferedSubscribeTestServer(t, db, chatID)
|
||||
@@ -2382,14 +2387,14 @@ func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -2422,14 +2427,14 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
@@ -2461,6 +2466,101 @@ func newTestRetryPayload() *codersdk.ChatStreamRetry {
|
||||
return payload
|
||||
}
|
||||
|
||||
func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
server := newSubscribeTestServer(t, db)
|
||||
|
||||
chatID := uuid.New()
|
||||
staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
state := server.getOrCreateStreamState(chatID)
|
||||
state.mu.Lock()
|
||||
state.buffer = []codersdk.ChatStreamEvent{{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: "assistant",
|
||||
Part: codersdk.ChatMessageText("thinking"),
|
||||
},
|
||||
}}
|
||||
state.mu.Unlock()
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
)
|
||||
|
||||
initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
require.Len(t, initialSnapshot, 2)
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type)
|
||||
require.NotNil(t, initialSnapshot[0].Status)
|
||||
require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status)
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type)
|
||||
require.NotNil(t, initialSnapshot[1].MessagePart)
|
||||
require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text)
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
server := newSubscribeTestServer(t, db)
|
||||
|
||||
chatID := uuid.New()
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).
|
||||
Return(database.Chat{}, dbauthz.NotAuthorizedError{Err: xerrors.New("not authorized")})
|
||||
|
||||
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.False(t, ok)
|
||||
require.Nil(t, snapshot)
|
||||
require.Nil(t, events)
|
||||
require.Nil(t, cancel)
|
||||
|
||||
_, exists := server.chatStreams.Load(chatID)
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
server := newSubscribeTestServer(t, db)
|
||||
|
||||
chatID := uuid.New()
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).
|
||||
Return(database.Chat{}, xerrors.New("transient lookup failure"))
|
||||
|
||||
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, cancel)
|
||||
require.Len(t, snapshot, 1)
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeError, snapshot[0].Type)
|
||||
require.Equal(t, chatID, snapshot[0].ChatID)
|
||||
require.Equal(t, "failed to load initial snapshot", snapshot[0].Error.Message)
|
||||
|
||||
_, open := <-events
|
||||
require.False(t, open)
|
||||
|
||||
_, exists := server.chatStreams.Load(chatID)
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -433,8 +433,13 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
|
||||
user, org, model := seedChatDependencies(t, db)
|
||||
|
||||
chat := seedRemoteRunningChat(ctx, t, db, org.ID, user, model, workerID, "relay-snapshot")
|
||||
staleChat := chat
|
||||
staleChat.Status = database.ChatStatusWaiting
|
||||
staleChat.WorkerID = uuid.NullUUID{}
|
||||
staleChat.StartedAt = sql.NullTime{}
|
||||
staleChat.HeartbeatAt = sql.NullTime{}
|
||||
|
||||
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
initialSnapshot, events, cancel, ok := subscriber.SubscribeAuthorized(ctx, staleChat, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
@@ -458,15 +463,15 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
|
||||
|
||||
require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts)
|
||||
|
||||
// The initial snapshot should still contain the status event
|
||||
// from the OSS preamble.
|
||||
var hasStatus bool
|
||||
// The initial snapshot should contain the refreshed running status,
|
||||
// not the stale waiting status passed into SubscribeAuthorized.
|
||||
var snapshotStatus codersdk.ChatStatus
|
||||
for _, event := range initialSnapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus {
|
||||
hasStatus = true
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
|
||||
snapshotStatus = event.Status.Status
|
||||
}
|
||||
}
|
||||
require.True(t, hasStatus, "initial snapshot should contain status event")
|
||||
require.Equal(t, codersdk.ChatStatusRunning, snapshotStatus)
|
||||
}
|
||||
|
||||
func TestSubscribeRetryEventAcrossInstances(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user