fix(coderd/x/chatd): refresh chat status and bound subscriber reads on Subscribe (#24095)

Tightens the chat stream subscription path on a few related axes. None
of these changes touch the steady-state event flow; they all concern the
subscribe handshake.

## Motivation

`Server.Subscribe` carries three responsibilities that were entangled:

1. Authorize the caller against the chat row.
2. Arm local + pubsub subscriptions before any DB reads
(subscribe-first-then-query).
3. Build the initial snapshot from a fresh chat row, message history,
and queue.

When all three live in one function and share the request context, a few
unfortunate behaviors fall out:

- The HTTP handler's middleware already loaded and authorized the chat
row, but `Subscribe(chatID)` discarded it and re-fetched on every
WebSocket connection.
- The chat row used to populate the initial `status` event was loaded
*before* the pubsub subscription was armed, so a status transition that
happened in that window was silently lost.
- Control-path DB reads inherited whatever context the caller passed in.
A caller without a deadline could wedge a subscriber goroutine
indefinitely on a stalled DB.
- A transient failure of the chat re-read collapsed the entire
subscription instead of degrading gracefully.

## What changes

**Split the auth boundary out into the type signature.** A new
`SubscribeAuthorized(ctx, chat, ...)` takes the already-authorized row
directly. The HTTP handler in `coderd/exp_chats.go` calls it with the
chat row from `httpmw.ChatParam`, eliminating the redundant
`GetChatByID`. `Subscribe(chatID)` is preserved as a thin wrapper for
callers that don't have a chat row in hand (tests, internal callers); it
does the auth lookup and delegates.

**Re-read the chat after arming subscriptions.** Inside
`SubscribeAuthorized`, after the local stream and pubsub subscriptions
are active, we reload the chat row to populate the initial `status`
event and any enterprise relay setup. Combined with the existing
subscribe-first-then-query pattern, this closes the gap where a status
transition between the middleware's load and the subscription arming
would not appear in either the initial snapshot or a live notification.

**Fall back to the middleware row on refresh failure.** If the
post-subscription refresh fails (transient DB blip, brief pool
exhaustion), we log a warning and reuse the row that proved
authorization in the first place. Messages, queue, and pubsub are all
independent of this row, so the stream still works; the initial `status`
is just slightly stale and self-corrects via the next pubsub event.

**Bound subscriber control-path DB reads.** A new
`streamSubscriberControlFetchContext` helper applies a 5-second fallback
timeout only when the caller has no deadline of its own. It is used for
the chat refresh, the initial queue load, and the queue reload that
follows a pubsub notification. HTTP-driven callers pass through
unchanged; background callers can no longer hang forever on a stalled DB
and leak subscriber goroutines, pubsub subscriptions, and `chatStreams`
entries.
This commit is contained in:
Ethan
2026-05-06 14:29:53 +10:00
committed by GitHub
parent 0dc4c34efc
commit e5c7fdff86
4 changed files with 247 additions and 78 deletions
+1 -1
View File
@@ -3261,7 +3261,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
// Subscribe before accepting the WebSocket so that failures
// can still be reported as normal HTTP errors.
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID)
// Subscribe only fails today when the receiver is nil, which
// the chatDaemon == nil guard above already catches. This is
// defensive against future Subscribe failure modes.
+115 -51
View File
@@ -107,6 +107,9 @@ const (
// cross-replica relay subscribers time to connect and
// snapshot the buffer before it is garbage-collected.
bufferRetainGracePeriod = 5 * time.Second
// chatStreamControlFetchTimeout bounds subscriber-owned
// control-path DB reads when the caller has no deadline.
chatStreamControlFetchTimeout = 5 * time.Second
// streamJanitorInterval is how often sweepIdleStreams runs.
// Worst-case retention is bufferRetainGracePeriod +
@@ -4244,6 +4247,31 @@ func (p *Server) heartbeatTick(ctx context.Context) {
}
}
// streamSubscriberControlFetchContext returns the context a subscriber should
// use for a control-path lookup. When ctx already carries a deadline it is
// returned as-is together with a no-op cancel func; otherwise the returned
// context is derived from ctx and expires after chatStreamControlFetchTimeout.
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
	_, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
	}
	// The caller owns the deadline, so there is nothing for us to release.
	noop := func() {}
	return ctx, noop
}
// subscribeWithInitialError builds a Subscribe-shaped result whose snapshot
// holds a single error event carrying message for chatID. The returned event
// channel is already closed, the cancel func does nothing, and ok is true.
func subscribeWithInitialError(chatID uuid.UUID, message string) (
	[]codersdk.ChatStreamEvent,
	<-chan codersdk.ChatStreamEvent,
	func(),
	bool,
) {
	errEvent := codersdk.ChatStreamEvent{
		Type:   codersdk.ChatStreamEventTypeError,
		ChatID: chatID,
		Error:  &codersdk.ChatError{Message: message},
	}
	snapshot := []codersdk.ChatStreamEvent{errEvent}

	// Hand back a closed channel so a receive returns immediately instead
	// of blocking.
	closed := make(chan codersdk.ChatStreamEvent)
	close(closed)

	noop := func() {}
	return snapshot, closed, noop, true
}
func (p *Server) Subscribe(
ctx context.Context,
chatID uuid.UUID,
@@ -4258,9 +4286,40 @@ func (p *Server) Subscribe(
if p == nil {
return nil, nil, nil, false
}
if ctx == nil {
ctx = context.Background()
chat, err := p.db.GetChatByID(ctx, chatID)
if err != nil {
if dbauthz.IsNotAuthorizedError(err) {
return nil, nil, nil, false
}
p.logger.Warn(ctx, "failed to load chat for stream subscription",
slog.F("chat_id", chatID),
slog.Error(err),
)
return subscribeWithInitialError(chatID, "failed to load initial snapshot")
}
return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
}
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
// updates. The passed chat row proves authorization, but SubscribeAuthorized
// still reloads the chat after the stream subscriptions are armed so the
// initial status and relay setup use fresh state.
func (p *Server) SubscribeAuthorized(
ctx context.Context,
chat database.Chat,
requestHeader http.Header,
afterMessageID int64,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
bool,
) {
if p == nil {
return nil, nil, nil, false
}
chatID := chat.ID
// Subscribe to the local stream for message_parts and same-replica
// persisted messages. Capture the current retry phase under the same
@@ -4326,6 +4385,34 @@ func (p *Server) Subscribe(
}
}
cancel := func() {
mergedCancel()
for _, cancelFn := range allCancels {
if cancelFn != nil {
cancelFn()
}
}
}
// Re-read the chat after the local/pubsub subscriptions are active so
// the initial status event and any enterprise relay setup use fresh
// state instead of the middleware-loaded row.
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
snapshotChat, err := func() (database.Chat, error) {
defer refreshCancel()
//nolint:gocritic // SubscribeAuthorized already validated the
// caller; this refresh only loads the latest status/worker for
// the already-authorized stream subscription.
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
}()
if err != nil {
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
slog.F("chat_id", chatID),
slog.Error(err),
)
snapshotChat = chat
}
// Build initial snapshot synchronously. The pubsub subscription
// is already active so no notifications can be lost during this
// window.
@@ -4377,8 +4464,12 @@ func (p *Server) Subscribe(
}
}
// Load initial queue.
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
// Load initial queue. Queue snapshots are intentionally not
// singleflighted because a chat-scoped key cannot distinguish the
// pre- and post-notification queue state.
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
queueCancel()
if err != nil {
p.logger.Error(ctx, "failed to load initial queued messages",
slog.Error(err),
@@ -4397,44 +4488,24 @@ func (p *Server) Subscribe(
})
}
// Get initial chat state to determine if we need a relay.
chat, chatErr := p.db.GetChatByID(ctx, chatID)
// Include the current chat status in the snapshot so the
// frontend can gate message_part processing correctly from
// the very first batch, without waiting for a separate REST
// query.
if chatErr != nil {
p.logger.Error(ctx, "failed to load initial chat state",
slog.Error(chatErr),
slog.F("chat_id", chatID),
)
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: &codersdk.ChatError{Message: "failed to load initial snapshot"},
})
} else {
statusEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{
Status: codersdk.ChatStatus(chat.Status),
},
}
// Prepend so the frontend sees the current stream phases
// before any message_part events.
prefix := []codersdk.ChatStreamEvent{statusEvent}
if retryEvent != nil {
prefix = append(prefix, *retryEvent)
retryEvent = nil
}
initialSnapshot = append(prefix, initialSnapshot...)
statusEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{
Status: codersdk.ChatStatus(snapshotChat.Status),
},
}
// Prepend so the frontend sees the current stream phases
// before any message_part events.
prefix := []codersdk.ChatStreamEvent{statusEvent}
if retryEvent != nil {
initialSnapshot = append(initialSnapshot, *retryEvent)
prefix = append(prefix, *retryEvent)
}
initialSnapshot = append(prefix, initialSnapshot...)
// Track the highest durable message ID delivered to this subscriber,
// whether it came from the initial DB snapshot, the same-replica local
@@ -4444,18 +4515,17 @@ func (p *Server) Subscribe(
lastMessageID = messages[len(messages)-1].ID
}
// When an enterprise SubscribeFn is provided and the chat
// lookup succeeded, call it to get relay events (message_parts
// from remote replicas). OSS now owns pubsub subscription,
// message catch-up, queue updates, and status forwarding;
// enterprise only manages relay dialing.
// When an enterprise SubscribeFn is provided, call it to get relay events
// (message_parts from remote replicas). OSS owns pubsub subscription,
// message catch-up, queue updates, and status forwarding; enterprise only
// manages relay dialing.
var relayEvents <-chan codersdk.ChatStreamEvent
var statusNotifications chan StatusNotification
if p.subscribeFn != nil && chatErr == nil {
if p.subscribeFn != nil {
statusNotifications = make(chan StatusNotification, 10)
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
ChatID: chatID,
Chat: chat,
Chat: snapshotChat,
WorkerID: p.workerID,
StatusNotifications: statusNotifications,
RequestHeader: requestHeader,
@@ -4600,7 +4670,9 @@ func (p *Server) Subscribe(
}
}
if notify.QueueUpdate {
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID)
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
queueCancel()
if queueErr != nil {
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
slog.F("chat_id", chatID),
@@ -4676,14 +4748,6 @@ func (p *Server) Subscribe(
}
}()
cancel := func() {
mergedCancel()
for _, cancelFn := range allCancels {
if cancelFn != nil {
cancelFn()
}
}
}
return initialSnapshot, mergedEvents, cancel, true
}
+119 -19
View File
@@ -1983,14 +1983,14 @@ func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -2026,14 +2026,14 @@ func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T)
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -2077,14 +2077,14 @@ func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
@@ -2126,14 +2126,14 @@ func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
@@ -2163,14 +2163,14 @@ func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -2200,12 +2200,13 @@ func TestSubscribeReplaysCurrentRetryPhaseInSnapshot(t *testing.T) {
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newBufferedSubscribeTestServer(t, db, chatID)
@@ -2241,6 +2242,8 @@ func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) {
server := newSubscribeTestServer(t, db)
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
@@ -2249,7 +2252,6 @@ func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) {
return nil, nil
}),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
@@ -2275,12 +2277,13 @@ func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) {
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newBufferedSubscribeTestServer(t, db, chatID)
@@ -2310,12 +2313,13 @@ func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) {
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newBufferedSubscribeTestServer(t, db, chatID)
@@ -2350,12 +2354,13 @@ func TestSubscribeDoesNotReplayRetryAfterTerminalStatus(t *testing.T) {
chat := database.Chat{ID: chatID, Status: database.ChatStatusCompleted}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newBufferedSubscribeTestServer(t, db, chatID)
@@ -2382,14 +2387,14 @@ func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -2422,14 +2427,14 @@ func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
@@ -2461,6 +2466,101 @@ func newTestRetryPayload() *codersdk.ChatStreamRetry {
return payload
}
// TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails checks that
// SubscribeAuthorized still succeeds when the GetChatByID call it issues
// returns an error: the initial status event carries the status of the chat
// row passed in by the caller, and the buffered message_part is still part of
// the snapshot.
func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
server := newSubscribeTestServer(t, db)
chatID := uuid.New()
// The row handed to SubscribeAuthorized; its Pending status is what the
// snapshot is expected to report once the refresh fails.
staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
// Seed the stream buffer with one message_part before subscribing so the
// snapshot has a second event to assert on.
state := server.getOrCreateStreamState(chatID)
state.mu.Lock()
state.buffer = []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
ChatID: chatID,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("thinking"),
},
}}
state.mu.Unlock()
// Expected DB calls, in order: the chat refresh (which fails), then the
// message history load, then the queue load.
gomock.InOrder(
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
)
initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0)
require.True(t, ok)
defer cancel()
// Exactly two events: the status taken from staleChat, followed by the
// buffered message_part. No error event is expected.
require.Len(t, initialSnapshot, 2)
require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type)
require.NotNil(t, initialSnapshot[0].Status)
require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status)
require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type)
require.NotNil(t, initialSnapshot[1].MessagePart)
require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text)
// Nothing further should arrive on the live channel within 200ms.
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches verifies that
// Subscribe reports failure with all-nil results when the chat lookup returns
// a dbauthz not-authorized error, and that no chatStreams entry exists for the
// chat afterwards.
func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	store := dbmock.NewMockStore(gomock.NewController(t))
	srv := newSubscribeTestServer(t, store)

	chatID := uuid.New()
	authzErr := dbauthz.NotAuthorizedError{Err: xerrors.New("not authorized")}
	// Only the authorization lookup is expected on the mock store.
	store.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, authzErr)

	initial, stream, cancelFn, ok := srv.Subscribe(ctx, chatID, nil, 0)
	require.False(t, ok)
	require.Nil(t, initial)
	require.Nil(t, stream)
	require.Nil(t, cancelFn)

	_, tracked := srv.chatStreams.Load(chatID)
	require.False(t, tracked)
}
// TestSubscribeSurfacesTransientLookupFailureAsInitialError verifies that a
// non-authorization failure of the chat lookup still yields ok == true, with a
// one-event error snapshot, an already-closed event channel, a non-nil cancel
// func, and no chatStreams entry for the chat.
func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	store := dbmock.NewMockStore(gomock.NewController(t))
	srv := newSubscribeTestServer(t, store)

	chatID := uuid.New()
	lookupErr := xerrors.New("transient lookup failure")
	store.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, lookupErr)

	initial, stream, cancelFn, ok := srv.Subscribe(ctx, chatID, nil, 0)
	require.True(t, ok)
	require.NotNil(t, cancelFn)

	require.Len(t, initial, 1)
	errEvent := initial[0]
	require.Equal(t, codersdk.ChatStreamEventTypeError, errEvent.Type)
	require.Equal(t, chatID, errEvent.ChatID)
	require.Equal(t, "failed to load initial snapshot", errEvent.Error.Message)

	// The channel must already be closed: the receive reports open == false.
	_, open := <-stream
	require.False(t, open)

	_, tracked := srv.chatStreams.Load(chatID)
	require.False(t, tracked)
}
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
t.Helper()
+12 -7
View File
@@ -433,8 +433,13 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
user, org, model := seedChatDependencies(t, db)
chat := seedRemoteRunningChat(ctx, t, db, org.ID, user, model, workerID, "relay-snapshot")
staleChat := chat
staleChat.Status = database.ChatStatusWaiting
staleChat.WorkerID = uuid.NullUUID{}
staleChat.StartedAt = sql.NullTime{}
staleChat.HeartbeatAt = sql.NullTime{}
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
initialSnapshot, events, cancel, ok := subscriber.SubscribeAuthorized(ctx, staleChat, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
@@ -458,15 +463,15 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts)
// The initial snapshot should still contain the status event
// from the OSS preamble.
var hasStatus bool
// The initial snapshot should contain the refreshed running status,
// not the stale waiting status passed into SubscribeAuthorized.
var snapshotStatus codersdk.ChatStatus
for _, event := range initialSnapshot {
if event.Type == codersdk.ChatStreamEventTypeStatus {
hasStatus = true
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
snapshotStatus = event.Status.Status
}
}
require.True(t, hasStatus, "initial snapshot should contain status event")
require.Equal(t, codersdk.ChatStatusRunning, snapshotStatus)
}
func TestSubscribeRetryEventAcrossInstances(t *testing.T) {