diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index db8745652a..67e219bbfd 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub. func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() apiKey := httpmw.APIKey(r) + logger := api.Logger.Named("chat_watcher") - sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r) + conn, err := websocket.Accept(rw, r, nil) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to open chat watch stream.", @@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) { }) return } - defer func() { - <-senderClosed - }() - cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID), - pubsub.HandleChatEvent( - func(ctx context.Context, payload pubsub.ChatEvent, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + _ = conn.CloseRead(context.Background()) + + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) + defer wsNetConn.Close() + + go httpapi.HeartbeatClose(ctx, logger, cancel, conn) + + // The encoder is only written from the SubscribeWithErr callback, + // which delivers serially per subscription. Do not add a second + // write path without introducing synchronization. + encoder := json.NewEncoder(wsNetConn) + + cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID), + pubsub.HandleChatWatchEvent( + func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) { if err != nil { - api.Logger.Error(ctx, "chat event subscription error", slog.Error(err)) + logger.Error(ctx, "chat watch event subscription error", slog.Error(err)) return } - if err := sendEvent(codersdk.ServerSentEvent{ - Type: codersdk.ServerSentEventTypeData, - Data: payload, - }); err != nil { - api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err)) + if err := encoder.Encode(payload); err != nil { + logger.Debug(ctx, "failed to send chat watch event", slog.Error(err)) + cancel() + return } }, )) if err != nil { - if err := sendEvent(codersdk.ServerSentEvent{ - Type: codersdk.ServerSentEventTypeError, - Data: codersdk.Response{ - Message: "Internal error subscribing to chat events.", - Detail: err.Error(), - }, - }); err != nil { - api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err)) - } + logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err)) + _ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.") return } defer cancelSubscribe() - // Send initial ping to signal the connection is ready. - if err := sendEvent(codersdk.ServerSentEvent{ - Type: codersdk.ServerSentEventTypePing, - }); err != nil { - api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err)) - } - - for { - select { - case <-ctx.Done(): - return - case <-senderClosed: - return - } - } + <-ctx.Done() } // EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to @@ -2176,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() chat := httpmw.ChatParam(r) chatID := chat.ID + logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID)) if api.chatDaemon == nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -2198,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { } } - sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r) + // Subscribe before accepting the WebSocket so that failures + // can still be reported as normal HTTP errors. + snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID) + // Subscribe only fails today when the receiver is nil, which + // the chatDaemon == nil guard above already catches. This is + // defensive against future Subscribe failure modes. + if !ok { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Chat streaming is not available.", + Detail: "Chat stream state is not configured.", + }) + return + } + defer cancelSub() + + conn, err := websocket.Accept(rw, r, nil) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to open chat stream.", @@ -2206,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { }) return } - snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID) - if !ok { - if err := sendEvent(codersdk.ServerSentEvent{ - Type: codersdk.ServerSentEventTypeError, - Data: codersdk.Response{ - Message: "Chat streaming is not available.", - Detail: "Chat stream state is not configured.", - }, - }); err != nil { - api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err)) - } - // Ensure the WebSocket is closed so senderClosed - // completes and the handler can return. - <-senderClosed - return - } - defer func() { - <-senderClosed - }() + + ctx, cancel := context.WithCancel(ctx) defer cancel() + _ = conn.CloseRead(context.Background()) + + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) + defer wsNetConn.Close() + + go httpapi.HeartbeatClose(ctx, logger, cancel, conn) + // Mark the chat as read when the stream connects and again // when it disconnects so we avoid per-message API calls while // messages are actively streaming. api.markChatAsRead(ctx, chatID) defer api.markChatAsRead(context.WithoutCancel(ctx), chatID) + encoder := json.NewEncoder(wsNetConn) + sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error { if len(batch) == 0 { return nil } - return sendEvent(codersdk.ServerSentEvent{ - Type: codersdk.ServerSentEventTypeData, - Data: batch, - }) + return encoder.Encode(batch) } drainChatStreamBatch := func( @@ -2273,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { end = len(snapshot) } if err := sendChatStreamBatch(snapshot[start:end]); err != nil { - api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err)) + logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err)) return } } @@ -2282,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { select { case <-ctx.Done(): return - case <-senderClosed: - return case firstEvent, ok := <-events: if !ok { return @@ -2293,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) { chatStreamBatchSize, ) if err := sendChatStreamBatch(batch); err != nil { - api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err)) + logger.Debug(ctx, "failed to send chat stream event", slog.Error(err)) return } if streamClosed { @@ -2308,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() chat := httpmw.ChatParam(r) chatID := chat.ID + logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID)) if api.chatDaemon != nil { chat = api.chatDaemon.InterruptChat(ctx, chat) @@ -2321,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) { LastError: sql.NullString{}, }) if updateErr != nil { - api.Logger.Error(ctx, "failed to mark chat as waiting", - slog.F("chat_id", chatID), slog.Error(updateErr)) + logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr)) httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to interrupt chat.", Detail: updateErr.Error(), diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index afc892d59d..2f5170957e 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -1114,17 +1114,6 @@ func TestWatchChats(t *testing.T) { require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") - type watchEvent struct { - Type codersdk.ServerSentEventType `json:"type"` - Data json.RawMessage `json:"data,omitempty"` - } - - var event watchEvent - err = wsjson.Read(ctx, conn, &event) - require.NoError(t, err) - require.Equal(t, codersdk.ServerSentEventTypePing, event.Type) - require.True(t, len(event.Data) == 0 || string(event.Data) == "null") - createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ Content: []codersdk.ChatInputPart{ { @@ -1136,25 +1125,16 @@ func TestWatchChats(t *testing.T) { require.NoError(t, err) for { - var update watchEvent - err = wsjson.Read(ctx, conn, &update) + var payload codersdk.ChatWatchEvent + err = wsjson.Read(ctx, conn, &payload) require.NoError(t, err) - if update.Type == codersdk.ServerSentEventTypePing { - continue - } - require.Equal(t, codersdk.ServerSentEventTypeData, update.Type) - - var payload coderdpubsub.ChatEvent - err = json.Unmarshal(update.Data, &payload) - require.NoError(t, err) - if payload.Kind == coderdpubsub.ChatEventKindCreated && + if payload.Kind == codersdk.ChatWatchEventKindCreated && payload.Chat.ID == createdChat.ID { break } } }) - t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) { t.Parallel() @@ -1174,18 +1154,6 @@ func TestWatchChats(t *testing.T) { require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") - type watchEvent struct { - Type codersdk.ServerSentEventType `json:"type"` - Data json.RawMessage `json:"data,omitempty"` - } - - // Skip the initial ping. - var event watchEvent - err = wsjson.Read(ctx, conn, &event) - require.NoError(t, err) - require.Equal(t, codersdk.ServerSentEventTypePing, event.Type) - require.True(t, len(event.Data) == 0 || string(event.Data) == "null") - createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ Content: []codersdk.ChatInputPart{ { @@ -1198,18 +1166,11 @@ func TestWatchChats(t *testing.T) { var got codersdk.Chat testutil.Eventually(ctx, t, func(_ context.Context) bool { - var update watchEvent - if readErr := wsjson.Read(ctx, conn, &update); readErr != nil { + var payload codersdk.ChatWatchEvent + if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil { return false } - if update.Type != codersdk.ServerSentEventTypeData { - return false - } - var payload coderdpubsub.ChatEvent - if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil { - return false - } - if payload.Kind == coderdpubsub.ChatEventKindCreated && + if payload.Kind == codersdk.ChatWatchEventKindCreated && payload.Chat.ID == createdChat.ID { got = payload.Chat return true @@ -1282,25 +1243,14 @@ func TestWatchChats(t *testing.T) { require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") - type watchEvent struct { - Type codersdk.ServerSentEventType `json:"type"` - Data json.RawMessage `json:"data,omitempty"` - } - - // Read the initial ping. - var ping watchEvent - err = wsjson.Read(ctx, conn, &ping) - require.NoError(t, err) - require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type) - // Publish a diff_status_change event via pubsub, // mimicking what PublishDiffStatusChange does after // it reads the diff status from the DB. dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus) - event := coderdpubsub.ChatEvent{ - Kind: coderdpubsub.ChatEventKindDiffStatusChange, + event := codersdk.ChatWatchEvent{ + Kind: codersdk.ChatWatchEventKindDiffStatusChange, Chat: codersdk.Chat{ ID: chat.ID, OwnerID: chat.OwnerID, @@ -1313,25 +1263,15 @@ func TestWatchChats(t *testing.T) { } payload, err := json.Marshal(event) require.NoError(t, err) - err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload) + err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload) require.NoError(t, err) - // Read events until we find the diff_status_change. for { - var update watchEvent - err = wsjson.Read(ctx, conn, &update) + var received codersdk.ChatWatchEvent + err = wsjson.Read(ctx, conn, &received) require.NoError(t, err) - if update.Type == codersdk.ServerSentEventTypePing { - continue - } - require.Equal(t, codersdk.ServerSentEventTypeData, update.Type) - - var received coderdpubsub.ChatEvent - err = json.Unmarshal(update.Data, &received) - require.NoError(t, err) - - if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange || + if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange || received.Chat.ID != chat.ID { continue } @@ -1350,7 +1290,6 @@ func TestWatchChats(t *testing.T) { break } }) - t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) { t.Parallel() @@ -1393,31 +1332,13 @@ func TestWatchChats(t *testing.T) { require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") - type watchEvent struct { - Type codersdk.ServerSentEventType `json:"type"` - Data json.RawMessage `json:"data,omitempty"` - } - - var ping watchEvent - err = wsjson.Read(ctx, conn, &ping) - require.NoError(t, err) - require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type) - - collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent { + collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent { t.Helper() - events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3) + events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3) for len(events) < 3 { - var update watchEvent - err = wsjson.Read(ctx, conn, &update) - require.NoError(t, err) - if update.Type == codersdk.ServerSentEventTypePing { - continue - } - require.Equal(t, codersdk.ServerSentEventTypeData, update.Type) - - var payload coderdpubsub.ChatEvent - err = json.Unmarshal(update.Data, &payload) + var payload codersdk.ChatWatchEvent + err = wsjson.Read(ctx, conn, &payload) require.NoError(t, err) if payload.Kind != expectedKind { continue @@ -1427,7 +1348,7 @@ func TestWatchChats(t *testing.T) { return events } - assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) { + assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) { t.Helper() require.Len(t, events, 3) @@ -1440,12 +1361,12 @@ func TestWatchChats(t *testing.T) { err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) - deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted) + deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted) assertLifecycleEvents(deletedEvents, true) err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)}) require.NoError(t, err) - createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated) + createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated) assertLifecycleEvents(createdEvents, false) }) diff --git a/coderd/pubsub/chatconfigevent.go b/coderd/pubsub/chatconfigevent.go index 896a2aaf82..60d495e157 100644 --- a/coderd/pubsub/chatconfigevent.go +++ b/coderd/pubsub/chatconfigevent.go @@ -14,7 +14,7 @@ import ( const ChatConfigEventChannel = "chat:config_change" // HandleChatConfigEvent wraps a typed callback for ChatConfigEvent -// messages, following the same pattern as HandleChatEvent. +// messages, following the same pattern as HandleChatWatchEvent. func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) { return func(ctx context.Context, message []byte, err error) { if err != nil { diff --git a/coderd/pubsub/chatevent.go b/coderd/pubsub/chatevent.go deleted file mode 100644 index 426d0e395a..0000000000 --- a/coderd/pubsub/chatevent.go +++ /dev/null @@ -1,49 +0,0 @@ -package pubsub - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/google/uuid" - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/codersdk" -) - -func ChatEventChannel(ownerID uuid.UUID) string { - return fmt.Sprintf("chat:owner:%s", ownerID) -} - -func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) { - return func(ctx context.Context, message []byte, err error) { - if err != nil { - cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err)) - return - } - var payload ChatEvent - if err := json.Unmarshal(message, &payload); err != nil { - cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err)) - return - } - - cb(ctx, payload, err) - } -} - -type ChatEvent struct { - Kind ChatEventKind `json:"kind"` - Chat codersdk.Chat `json:"chat"` - ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"` -} - -type ChatEventKind string - -const ( - ChatEventKindStatusChange ChatEventKind = "status_change" - ChatEventKindTitleChange ChatEventKind = "title_change" - ChatEventKindCreated ChatEventKind = "created" - ChatEventKindDeleted ChatEventKind = "deleted" - ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change" - ChatEventKindActionRequired ChatEventKind = "action_required" -) diff --git a/coderd/pubsub/chatwatchevent.go b/coderd/pubsub/chatwatchevent.go new file mode 100644 index 0000000000..d844c88988 --- /dev/null +++ b/coderd/pubsub/chatwatchevent.go @@ -0,0 +1,36 @@ +package pubsub + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" +) + +// ChatWatchEventChannel returns the pubsub channel for chat +// lifecycle events scoped to a single user. +func ChatWatchEventChannel(ownerID uuid.UUID) string { + return fmt.Sprintf("chat:owner:%s", ownerID) +} + +// HandleChatWatchEvent wraps a typed callback for +// ChatWatchEvent messages delivered via pubsub. +func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) { + return func(ctx context.Context, message []byte, err error) { + if err != nil { + cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err)) + return + } + var payload codersdk.ChatWatchEvent + if err := json.Unmarshal(message, &payload); err != nil { + cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err)) + return + } + + cb(ctx, payload, err) + } +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 2cd6a3a7fc..6a77fdee8d 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -996,7 +996,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C return database.Chat{}, txErr } - p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil) + p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil) p.signalWake() return chat, nil } @@ -1158,7 +1158,7 @@ func (p *Server) SendMessage( p.publishMessage(opts.ChatID, result.Message) p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID) - p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil) + p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil) p.signalWake() return result, nil } @@ -1301,7 +1301,7 @@ func (p *Server) EditMessage( QueueUpdate: true, }) p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID) - p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil) + p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil) p.signalWake() return result, nil @@ -1355,10 +1355,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error { if interrupted { p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID) - p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil) + p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil) } - p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted) + p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted) return nil } @@ -1373,7 +1373,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error { ctx, chat.ID, "unarchive", - coderdpubsub.ChatEventKindCreated, + codersdk.ChatWatchEventKindCreated, p.db.UnarchiveChatByID, ) } @@ -1382,7 +1382,7 @@ func (p *Server) applyChatLifecycleTransition( ctx context.Context, chatID uuid.UUID, action string, - kind coderdpubsub.ChatEventKind, + kind codersdk.ChatWatchEventKind, transition func(context.Context, uuid.UUID) ([]database.Chat, error), ) error { updatedChats, err := transition(ctx, chatID) @@ -1545,7 +1545,7 @@ func (p *Server) PromoteQueued( }) p.publishMessage(opts.ChatID, promoted) p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID) - p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil) + p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) p.signalWake() return result, nil @@ -2092,7 +2092,7 @@ func (p *Server) regenerateChatTitleWithStore( return updatedChat, nil } - p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil) + p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil) return updatedChat, nil } @@ -2347,7 +2347,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database return database.Chat{}, err } p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID) - p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil) + p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) return updatedChat, nil } @@ -3627,7 +3627,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C } // publishChatPubsubEvents broadcasts a lifecycle event for each affected chat. -func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) { +func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) { for _, chat := range chats { p.publishChatPubsubEvent(chat, kind, nil) } @@ -3635,7 +3635,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsu // publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL // pubsub so that all replicas can push updates to watching clients. -func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) { +func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) { if p.pubsub == nil { return } @@ -3647,7 +3647,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch if diffStatus != nil { sdkChat.DiffStatus = diffStatus } - event := coderdpubsub.ChatEvent{ + event := codersdk.ChatWatchEvent{ Kind: kind, Chat: sdkChat, } @@ -3659,7 +3659,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch ) return } - if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil { + if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil { p.logger.Error(context.Background(), "failed to publish chat pubsub event", slog.F("chat_id", chat.ID), slog.F("kind", kind), @@ -3692,8 +3692,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo toolCalls := pendingToStreamToolCalls(pending) sdkChat := db2sdk.Chat(chat, nil, nil) - event := coderdpubsub.ChatEvent{ - Kind: coderdpubsub.ChatEventKindActionRequired, + event := codersdk.ChatWatchEvent{ + Kind: codersdk.ChatWatchEventKindActionRequired, Chat: sdkChat, ToolCalls: toolCalls, } @@ -3705,7 +3705,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo ) return } - if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil { + if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil { p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event", slog.F("chat_id", chat.ID), slog.Error(err), @@ -3733,7 +3733,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) } sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus) - p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus) + p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus) return nil } @@ -4215,7 +4215,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { if title, ok := generatedTitle.Load(); ok { updatedChat.Title = title } - p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil) + p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) // When the chat is parked in requires_action, // publish the stream event and global pubsub event diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index bf5ef8eaec..71de8a2a93 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -71,14 +71,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { updatedChat.Title = wantTitle messageEvents := make(chan struct { - payload coderdpubsub.ChatEvent + payload codersdk.ChatWatchEvent err error }, 1) cancelSub, err := pubsub.SubscribeWithErr( - coderdpubsub.ChatEventChannel(ownerID), - coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) { + coderdpubsub.ChatWatchEventChannel(ownerID), + coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) { messageEvents <- struct { - payload coderdpubsub.ChatEvent + payload codersdk.ChatWatchEvent err error }{payload: payload, err: err} }), @@ -184,7 +184,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { select { case event := <-messageEvents: require.NoError(t, event.err) - require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind) + require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind) require.Equal(t, chatID, event.payload.Chat.ID) require.Equal(t, wantTitle, event.payload.Chat.Title) case <-time.After(time.Second): @@ -234,14 +234,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t unlockedChat.StartedAt = sql.NullTime{} messageEvents := make(chan struct { - payload coderdpubsub.ChatEvent + payload codersdk.ChatWatchEvent err error }, 1) cancelSub, err := pubsub.SubscribeWithErr( - coderdpubsub.ChatEventChannel(ownerID), - coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) { + coderdpubsub.ChatWatchEventChannel(ownerID), + coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) { messageEvents <- struct { - payload coderdpubsub.ChatEvent + payload codersdk.ChatWatchEvent err error }{payload: payload, err: err} }), @@ -373,7 +373,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t select { case event := <-messageEvents: require.NoError(t, event.err) - require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind) + require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind) require.Equal(t, chatID, event.payload.Chat.ID) require.Equal(t, wantTitle, event.payload.Chat.Title) case <-time.After(time.Second): diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index 9fadfcff37..b82249a23c 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -21,7 +21,6 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" - coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/coderd/x/chatd/chatretry" @@ -160,7 +159,7 @@ func (p *Server) maybeGenerateChatTitle( } chat.Title = title generatedTitle.Store(title) - p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil) + p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil) return } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 330c22029a..f0f5211f0e 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -574,7 +574,7 @@ func (p *Server) createChildSubagentChatWithOptions( return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr) } - p.publishChatPubsubEvent(child, coderdpubsub.ChatEventKindCreated, nil) + p.publishChatPubsubEvent(child, codersdk.ChatWatchEventKindCreated, nil) p.signalWake() return child, nil } diff --git a/codersdk/chats.go b/codersdk/chats.go index 11062ed052..676c66da88 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1130,11 +1130,6 @@ type ChatStreamEvent struct { ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"` } -type chatStreamEnvelope struct { - Type ServerSentEventType `json:"type"` - Data json.RawMessage `json:"data,omitempty"` -} - // ChatCostSummaryOptions are optional query parameters for GetChatCostSummary. type ChatCostSummaryOptions struct { StartDate time.Time @@ -1987,8 +1982,8 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o }() for { - var envelope chatStreamEnvelope - if err := wsjson.Read(streamCtx, conn, &envelope); err != nil { + var batch []ChatStreamEvent + if err := wsjson.Read(streamCtx, conn, &batch); err != nil { if streamCtx.Err() != nil { return } @@ -2005,61 +2000,10 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o return } - switch envelope.Type { - case ServerSentEventTypePing: - continue - case ServerSentEventTypeData: - var batch []ChatStreamEvent - decodeErr := json.Unmarshal(envelope.Data, &batch) - if decodeErr == nil { - for _, streamedEvent := range batch { - if !send(streamedEvent) { - return - } - } - continue - } - - { - _ = send(ChatStreamEvent{ - Type: ChatStreamEventTypeError, - Error: &ChatStreamError{ - Message: fmt.Sprintf( - "decode chat stream event batch: %v", - decodeErr, - ), - }, - }) + for _, event := range batch { + if !send(event) { return } - case ServerSentEventTypeError: - message := "chat stream returned an error" - if len(envelope.Data) > 0 { - var response Response - if err := json.Unmarshal(envelope.Data, &response); err == nil { - message = formatChatStreamResponseError(response) - } else { - trimmed := strings.TrimSpace(string(envelope.Data)) - if trimmed != "" { - message = trimmed - } - } - } - _ = send(ChatStreamEvent{ - Type: ChatStreamEventTypeError, - Error: &ChatStreamError{ - Message: message, - }, - }) - return - default: - _ = send(ChatStreamEvent{ - Type: ChatStreamEventTypeError, - Error: &ChatStreamError{ - Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type), - }, - }) - return } } }() @@ -2098,8 +2042,8 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv }() for { - var envelope chatStreamEnvelope - if err := wsjson.Read(streamCtx, conn, &envelope); err != nil { + var event ChatWatchEvent + if err := wsjson.Read(streamCtx, conn, &event); err != nil { if streamCtx.Err() != nil { return } @@ -2110,23 +2054,10 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv return } - switch envelope.Type { - case ServerSentEventTypePing: - continue - case ServerSentEventTypeData: - var event ChatWatchEvent - if err := json.Unmarshal(envelope.Data, &event); err != nil { - return - } - select { - case <-streamCtx.Done(): - return - case events <- event: - } - case ServerSentEventTypeError: - return - default: + select { + case <-streamCtx.Done(): return + case events <- event: } } }() @@ -2478,21 +2409,6 @@ func (c *ExperimentalClient) GetChatsByWorkspace(ctx context.Context, workspaceI return result, json.NewDecoder(res.Body).Decode(&result) } -func formatChatStreamResponseError(response Response) string { - message := strings.TrimSpace(response.Message) - detail := strings.TrimSpace(response.Detail) - switch { - case message == "" && detail == "": - return "chat stream returned an error" - case message == "": - return detail - case detail == "": - return message - default: - return fmt.Sprintf("%s: %s", message, detail) - } -} - // PRInsightsResponse is the response from the PR insights endpoint. type PRInsightsResponse struct { Summary PRInsightsSummary `json:"summary"` diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 31eae1e2ba..3cb622454b 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -145,7 +145,7 @@ export const watchWorkspace = ( export const watchChat = ( chatId: string, afterMessageId?: number, -): OneWayWebSocketApi => { +): OneWayWebSocketApi => { const params = new URLSearchParams(); if (afterMessageId !== undefined && afterMessageId > 0) { params.set("after_id", afterMessageId.toString()); @@ -161,7 +161,7 @@ export const watchChat = ( }); }; -export const watchChats = (): OneWayWebSocket => { +export const watchChats = (): OneWayWebSocket => { const searchParams: Record = {}; const token = API.getSessionToken(); if (token) { diff --git a/site/src/pages/AgentsPage/AgentChatPage.stories.tsx b/site/src/pages/AgentsPage/AgentChatPage.stories.tsx index be347cb264..b47fe44902 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.stories.tsx +++ b/site/src/pages/AgentsPage/AgentChatPage.stories.tsx @@ -198,14 +198,6 @@ const buildQueries = ( ]; }; -/** - * Wrap a chat stream event payload in the JSON string format that - * OneWayWebSocket expects when receiving a WebSocket message event. - * The result is a `ServerSentEvent` of type `"data"` serialised to JSON. - */ -const wrapSSE = (payload: unknown): string => - JSON.stringify({ type: "data", data: payload }); - // --------------------------------------------------------------------------- // Meta // --------------------------------------------------------------------------- @@ -856,17 +848,20 @@ export const StreamedSubagentTitle: Story = { "/chats/": [ { event: "message", - data: wrapSSE({ - type: "message_part", - message_part: { - part: { - type: "tool-call", - tool_call_id: "tool-subagent-stream-1", - tool_name: "spawn_agent", - args_delta: '{"title":"Streamed Child"', + data: JSON.stringify([ + { + type: "message_part", + chat_id: CHAT_ID, + message_part: { + part: { + type: "tool-call", + tool_call_id: "tool-subagent-stream-1", + tool_name: "spawn_agent", + args_delta: '{"title":"Streamed Child"', + }, }, }, - }), + ] satisfies TypesGen.ChatStreamEvent[]), }, ], }, @@ -1150,15 +1145,18 @@ export const StreamedReasoning: Story = { "/chats/": [ { event: "message", - data: wrapSSE({ - type: "message_part", - message_part: { - part: { - type: "reasoning", - text: "Streaming reasoning body", + data: JSON.stringify([ + { + type: "message_part", + chat_id: CHAT_ID, + message_part: { + part: { + type: "reasoning", + text: "Streaming reasoning body", + }, }, }, - }), + ] satisfies TypesGen.ChatStreamEvent[]), }, ], }, @@ -1230,18 +1228,20 @@ export const WithWaitAgentComputerUseVNC: Story = { "/chats/": [ { event: "message", - data: wrapSSE({ - type: "message_part", - chat_id: CHAT_ID, - message_part: { - part: { - type: "tool-call", - tool_call_id: "tool-wait-desktop", - tool_name: "wait_agent", - args_delta: '{"chat_id":"desktop-child-1"}', + data: JSON.stringify([ + { + type: "message_part", + chat_id: CHAT_ID, + message_part: { + part: { + type: "tool-call", + tool_call_id: "tool-wait-desktop", + tool_name: "wait_agent", + args_delta: '{"chat_id":"desktop-child-1"}', + }, }, }, - }), + ] satisfies TypesGen.ChatStreamEvent[]), }, ], }, diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index 541151d12f..a2ffa7535b 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -51,7 +51,6 @@ import { chatDetailErrorsEqual, } from "./utils/usageLimitMessage"; -// Type guard for SSE events from the chat list watch endpoint. // Shallow-compare two ChatDiffStatus objects by their meaningful // fields, ignoring refreshed_at/stale_at which change on every poll. function diffStatusEqual( @@ -75,19 +74,6 @@ function diffStatusEqual( ); } -function isChatListSSEEvent( - data: unknown, -): data is { kind: string; chat: TypesGen.Chat } { - if (typeof data !== "object" || data === null) return false; - const obj = data as Record; - return ( - typeof obj.kind === "string" && - typeof obj.chat === "object" && - obj.chat !== null && - "id" in obj.chat - ); -} - export type { AgentsOutletContext } from "./AgentsPageView"; const AgentsPage: FC = () => { @@ -495,14 +481,7 @@ const AgentsPage: FC = () => { console.warn("Failed to parse chat watch event:", event.parseError); return; } - const sse = event.parsedMessage; - if (sse?.type !== "data" || !sse.data) { - return; - } - if (!isChatListSSEEvent(sse.data)) { - return; - } - const chatEvent = sse.data; + const chatEvent = event.parsedMessage; const updatedChat = chatEvent.chat; // Read the previous status from the infinite chat list // cache before we write the update below. The per-chat diff --git a/site/src/pages/AgentsPage/components/ChatConversation/chatStore.test.tsx b/site/src/pages/AgentsPage/components/ChatConversation/chatStore.test.tsx index 6093ae74ef..01105e1478 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/chatStore.test.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/chatStore.test.tsx @@ -55,7 +55,7 @@ vi.mock("#/api/api", () => ({ })); type MessageListener = ( - payload: OneWayMessageEvent, + payload: OneWayMessageEvent, ) => void; type ErrorListener = (payload: Event) => void; type OpenListener = (payload: Event) => void; @@ -67,6 +67,7 @@ type MockSocketHelpers = { emitOpen: () => void; emitData: (event: TypesGen.ChatStreamEvent) => void; emitDataBatch: (events: readonly TypesGen.ChatStreamEvent[]) => void; + emitParseError: () => void; emitError: () => void; emitClose: () => void; }; @@ -143,26 +144,30 @@ const createMockSocket = (): MockSocket => { removeEventListener, close: vi.fn(), emitData: (event) => { - const payload: OneWayMessageEvent = { + const payload: OneWayMessageEvent = { sourceEvent: {} as MessageEvent, parseError: undefined, - parsedMessage: { - type: "data", - data: event, - }, + parsedMessage: [event], }; for (const listener of messageListeners) { listener(payload); } }, emitDataBatch: (events) => { - const payload: OneWayMessageEvent = { + const payload: OneWayMessageEvent = { sourceEvent: {} as MessageEvent, parseError: undefined, - parsedMessage: { - type: "data", - data: events, - }, + parsedMessage: events as TypesGen.ChatStreamEvent[], + }; + for (const listener of messageListeners) { + listener(payload); + } + }, + emitParseError: () => { + const payload: OneWayMessageEvent = { + sourceEvent: {} as MessageEvent, + parseError: new Error("bad json"), + parsedMessage: undefined, }; for (const listener of messageListeners) { listener(payload); @@ -4209,3 +4214,215 @@ describe("store/cache desync protection", () => { }); }); }); + +describe("parse errors", () => { + it("surfaces parseError as streamError", async () => { + immediateAnimationFrame(); + + const chatID = "chat-parse-error"; + const mockSocket = createMockSocket(); + mockWatchChatReturn(mockSocket); + + const queryClient = createTestQueryClient(); + const wrapper = ({ children }: PropsWithChildren) => ( + {children} + ); + const setChatErrorReason = vi.fn(); + const clearChatErrorReason = vi.fn(); + + const { result } = renderHook( + () => { + const { store } = useChatStore({ + chatID, + chatMessages: [], + chatRecord: makeChat(chatID), + chatMessagesData: { + messages: [], + queued_messages: [], + has_more: false, + }, + chatQueuedMessages: [], + setChatErrorReason, + clearChatErrorReason, + }); + return { + streamError: useChatSelector(store, selectStreamError), + chatStatus: useChatSelector(store, selectChatStatus), + }; + }, + { wrapper }, + ); + + await waitFor(() => { + expect(watchChat).toHaveBeenCalledWith(chatID, undefined); + }); + + act(() => { + mockSocket.emitParseError(); + }); + + await waitFor(() => { + expect(result.current.streamError).toEqual({ + kind: "generic", + message: "Failed to parse chat stream update.", + }); + }); + expect(result.current.chatStatus).not.toBe("error"); + }); + + it("does not corrupt in-progress stream state", async () => { + immediateAnimationFrame(); + + const chatID = "chat-parse-no-corrupt"; + const existingMessage = makeMessage(chatID, 1, "user", "hello"); + const mockSocket = createMockSocket(); + mockWatchChatReturn(mockSocket); + + const queryClient = createTestQueryClient(); + const wrapper = ({ children }: PropsWithChildren) => ( + {children} + ); + const setChatErrorReason = vi.fn(); + const clearChatErrorReason = vi.fn(); + + const { result } = renderHook( + () => { + const { store } = useChatStore({ + chatID, + chatMessages: [existingMessage], + chatRecord: makeChat(chatID), + chatMessagesData: { + messages: [existingMessage], + queued_messages: [], + has_more: false, + }, + chatQueuedMessages: [], + setChatErrorReason, + clearChatErrorReason, + }); + return { + streamState: useChatSelector(store, selectStreamState), + streamError: useChatSelector(store, selectStreamError), + }; + }, + { wrapper }, + ); + + await waitFor(() => { + expect(watchChat).toHaveBeenCalledWith(chatID, 1); + }); + + // Build up some stream state first. + act(() => { + mockSocket.emitData({ + type: "message_part", + chat_id: chatID, + message_part: { + role: "assistant", + part: { type: "text", text: "partial response" }, + }, + }); + }); + + await waitFor(() => { + expect(result.current.streamState?.blocks).toEqual([ + { type: "response", text: "partial response" }, + ]); + }); + + // Fire a parse error and verify the existing stream blocks survive. + act(() => { + mockSocket.emitParseError(); + }); + + await waitFor(() => { + expect(result.current.streamError).toEqual({ + kind: "generic", + message: "Failed to parse chat stream update.", + }); + }); + expect(result.current.streamState?.blocks).toEqual([ + { type: "response", text: "partial response" }, + ]); + }); + + it("continues processing after parse error", async () => { + immediateAnimationFrame(); + + const chatID = "chat-parse-recover"; + const existingMessage = makeMessage(chatID, 1, "user", "hello"); + const mockSocket = createMockSocket(); + mockWatchChatReturn(mockSocket); + + const queryClient = createTestQueryClient(); + const wrapper = ({ children }: PropsWithChildren) => ( + {children} + ); + const setChatErrorReason = vi.fn(); + const clearChatErrorReason = vi.fn(); + + const { result } = renderHook( + () => { + const { store } = useChatStore({ + chatID, + chatMessages: [existingMessage], + chatRecord: makeChat(chatID), + chatMessagesData: { + messages: [existingMessage], + queued_messages: [], + has_more: false, + }, + chatQueuedMessages: [], + setChatErrorReason, + clearChatErrorReason, + }); + return { + streamState: useChatSelector(store, selectStreamState), + streamError: useChatSelector(store, selectStreamError), + }; + }, + { wrapper }, + ); + + await waitFor(() => { + expect(watchChat).toHaveBeenCalledWith(chatID, 1); + }); + + // Trigger a parse error first. + act(() => { + mockSocket.emitParseError(); + }); + + await waitFor(() => { + expect(result.current.streamError).toEqual({ + kind: "generic", + message: "Failed to parse chat stream update.", + }); + }); + + // Send a valid message_part after the parse error. + act(() => { + mockSocket.emitData({ + type: "message_part", + chat_id: chatID, + message_part: { + role: "assistant", + part: { type: "text", text: "recovered" }, + }, + }); + }); + + // The stream should process the new part normally. + await waitFor(() => { + expect(result.current.streamState?.blocks).toEqual([ + { type: "response", text: "recovered" }, + ]); + }); + + // streamError is sticky and is not cleared by valid messages. + expect(result.current.streamError).toEqual({ + kind: "generic", + message: "Failed to parse chat stream update.", + }); + }); +}); diff --git a/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts b/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts index ab660ec4b7..1dad72051e 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts +++ b/site/src/pages/AgentsPage/components/ChatConversation/useChatStore.ts @@ -6,7 +6,6 @@ import type * as TypesGen from "#/api/typesGenerated"; import type { OneWayMessageEvent } from "#/utils/OneWayWebSocket"; import { createReconnectingWebSocket } from "#/utils/reconnectingWebSocket"; import type { ChatDetailError } from "../../utils/usageLimitMessage"; -import { asNumber, asString } from "../ChatElements/runtimeTypeUtils"; import { type ChatStore, type ChatStoreState, @@ -17,50 +16,24 @@ import { } from "./chatStore"; import type { RetryState } from "./types"; -const isChatStreamEvent = (data: unknown): data is TypesGen.ChatStreamEvent => - typeof data === "object" && - data !== null && - "type" in data && - typeof (data as Record).type === "string"; - -const isChatStreamEventArray = ( - data: unknown, -): data is TypesGen.ChatStreamEvent[] => - Array.isArray(data) && data.every(isChatStreamEvent); - -const toChatStreamEvents = (data: unknown): TypesGen.ChatStreamEvent[] => { - if (isChatStreamEvent(data)) { - return [data]; - } - if (isChatStreamEventArray(data)) { - return data; - } - return []; -}; - const normalizeChatDetailError = ( - error: TypesGen.ChatStreamError | Record | undefined, + error: TypesGen.ChatStreamError | undefined, ): ChatDetailError => ({ - message: asString(error?.message).trim() || "Chat processing failed.", - kind: asString(error?.kind).trim() || "generic", - provider: asString(error?.provider).trim() || undefined, - retryable: - typeof error?.retryable === "boolean" ? error.retryable : undefined, - statusCode: asNumber(error?.status_code), + message: error?.message.trim() || "Chat processing failed.", + kind: error?.kind?.trim() || "generic", + provider: error?.provider?.trim() || undefined, + retryable: error?.retryable, + statusCode: error?.status_code, }); -const normalizeRetryState = (retry: TypesGen.ChatStreamRetry): RetryState => { - const delayMs = asNumber(retry.delay_ms); - const retryingAt = asString(retry.retrying_at).trim() || undefined; - return { - attempt: Math.max(1, asNumber(retry.attempt) ?? 1), - error: asString(retry.error).trim() || "Retrying request shortly.", - kind: asString(retry.kind).trim() || "generic", - provider: asString(retry.provider).trim() || undefined, - ...(delayMs !== undefined ? { delayMs } : {}), - ...(retryingAt ? { retryingAt } : {}), - }; -}; +const normalizeRetryState = (retry: TypesGen.ChatStreamRetry): RetryState => ({ + attempt: Math.max(1, retry.attempt), + error: retry.error.trim() || "Retrying request shortly.", + kind: retry.kind?.trim() || "generic", + provider: retry.provider?.trim() || undefined, + delayMs: retry.delay_ms, + retryingAt: retry.retrying_at.trim() || undefined, +}); const shouldSurfaceReconnectState = (state: ChatStoreState): boolean => state.streamError === null && @@ -419,7 +392,7 @@ export const useChatStore = ( }; const handleMessage = ( - payload: OneWayMessageEvent, + payload: OneWayMessageEvent, ) => { if (disposed) { return; @@ -431,11 +404,8 @@ export const useChatStore = ( }); return; } - if (payload.parsedMessage.type !== "data") { - return; - } - const streamEvents = toChatStreamEvents(payload.parsedMessage.data); + const streamEvents = payload.parsedMessage; if (streamEvents.length === 0) { return; }