mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor: send raw typed payloads over chat WebSockets (#24148)
This commit is contained in:
+61
-67
@@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -2176,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2198,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2206,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
return encoder.Encode(batch)
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2273,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2282,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2293,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2308,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2321,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
|
||||
+19
-98
@@ -1114,17 +1114,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1136,25 +1125,16 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1174,18 +1154,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Skip the initial ping.
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1198,18 +1166,11 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
var got codersdk.Chat
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
var update watchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
|
||||
return false
|
||||
}
|
||||
if update.Type != codersdk.ServerSentEventTypeData {
|
||||
return false
|
||||
}
|
||||
var payload coderdpubsub.ChatEvent
|
||||
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
got = payload.Chat
|
||||
return true
|
||||
@@ -1282,25 +1243,14 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
@@ -1313,25 +1263,15 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var received codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
@@ -1350,7 +1290,6 @@ func TestWatchChats(t *testing.T) {
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1393,31 +1332,13 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
@@ -1427,7 +1348,7 @@ func TestWatchChats(t *testing.T) {
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
@@ -1440,12 +1361,12 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
const ChatConfigEventChannel = "chat:config_change"
|
||||
|
||||
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
|
||||
// messages, following the same pattern as HandleChatEvent.
|
||||
// messages, following the same pattern as HandleChatWatchEvent.
|
||||
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
ChatEventKindActionRequired ChatEventKind = "action_required"
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ChatWatchEventChannel returns the pubsub channel for chat
|
||||
// lifecycle events scoped to a single user.
|
||||
func ChatWatchEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
// HandleChatWatchEvent wraps a typed callback for
|
||||
// ChatWatchEvent messages delivered via pubsub.
|
||||
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
+19
-19
@@ -996,7 +996,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return database.Chat{}, txErr
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return chat, nil
|
||||
}
|
||||
@@ -1158,7 +1158,7 @@ func (p *Server) SendMessage(
|
||||
|
||||
p.publishMessage(opts.ChatID, result.Message)
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
return result, nil
|
||||
}
|
||||
@@ -1301,7 +1301,7 @@ func (p *Server) EditMessage(
|
||||
QueueUpdate: true,
|
||||
})
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -1355,10 +1355,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
|
||||
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
ctx,
|
||||
chat.ID,
|
||||
"unarchive",
|
||||
coderdpubsub.ChatEventKindCreated,
|
||||
codersdk.ChatWatchEventKindCreated,
|
||||
p.db.UnarchiveChatByID,
|
||||
)
|
||||
}
|
||||
@@ -1382,7 +1382,7 @@ func (p *Server) applyChatLifecycleTransition(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
action string,
|
||||
kind coderdpubsub.ChatEventKind,
|
||||
kind codersdk.ChatWatchEventKind,
|
||||
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
|
||||
) error {
|
||||
updatedChats, err := transition(ctx, chatID)
|
||||
@@ -1545,7 +1545,7 @@ func (p *Server) PromoteQueued(
|
||||
})
|
||||
p.publishMessage(opts.ChatID, promoted)
|
||||
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -2092,7 +2092,7 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2347,7 +2347,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
|
||||
return database.Chat{}, err
|
||||
}
|
||||
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -3627,7 +3627,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
|
||||
}
|
||||
|
||||
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
|
||||
for _, chat := range chats {
|
||||
p.publishChatPubsubEvent(chat, kind, nil)
|
||||
}
|
||||
@@ -3635,7 +3635,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsu
|
||||
|
||||
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
||||
// pubsub so that all replicas can push updates to watching clients.
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
@@ -3647,7 +3647,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
if diffStatus != nil {
|
||||
sdkChat.DiffStatus = diffStatus
|
||||
}
|
||||
event := coderdpubsub.ChatEvent{
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: kind,
|
||||
Chat: sdkChat,
|
||||
}
|
||||
@@ -3659,7 +3659,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("kind", kind),
|
||||
@@ -3692,8 +3692,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
toolCalls := pendingToStreamToolCalls(pending)
|
||||
sdkChat := db2sdk.Chat(chat, nil, nil)
|
||||
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindActionRequired,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindActionRequired,
|
||||
Chat: sdkChat,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
@@ -3705,7 +3705,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
@@ -3733,7 +3733,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID)
|
||||
}
|
||||
|
||||
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4215,7 +4215,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
if title, ok := generatedTitle.Load(); ok {
|
||||
updatedChat.Title = title
|
||||
}
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
|
||||
// When the chat is parked in requires_action,
|
||||
// publish the stream event and global pubsub event
|
||||
|
||||
@@ -71,14 +71,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
updatedChat.Title = wantTitle
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -184,7 +184,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -234,14 +234,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
unlockedChat.StartedAt = sql.NullTime{}
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
payload codersdk.ChatWatchEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -373,7 +373,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -160,7 +159,7 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
chat.Title = title
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -574,7 +574,7 @@ func (p *Server) createChildSubagentChatWithOptions(
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(child, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.publishChatPubsubEvent(child, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return child, nil
|
||||
}
|
||||
|
||||
+9
-93
@@ -1130,11 +1130,6 @@ type ChatStreamEvent struct {
|
||||
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
|
||||
}
|
||||
|
||||
type chatStreamEnvelope struct {
|
||||
Type ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
|
||||
type ChatCostSummaryOptions struct {
|
||||
StartDate time.Time
|
||||
@@ -1987,8 +1982,8 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
var batch []ChatStreamEvent
|
||||
if err := wsjson.Read(streamCtx, conn, &batch); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -2005,61 +2000,10 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var batch []ChatStreamEvent
|
||||
decodeErr := json.Unmarshal(envelope.Data, &batch)
|
||||
if decodeErr == nil {
|
||||
for _, streamedEvent := range batch {
|
||||
if !send(streamedEvent) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
{
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf(
|
||||
"decode chat stream event batch: %v",
|
||||
decodeErr,
|
||||
),
|
||||
},
|
||||
})
|
||||
for _, event := range batch {
|
||||
if !send(event) {
|
||||
return
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
message := "chat stream returned an error"
|
||||
if len(envelope.Data) > 0 {
|
||||
var response Response
|
||||
if err := json.Unmarshal(envelope.Data, &response); err == nil {
|
||||
message = formatChatStreamResponseError(response)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(string(envelope.Data))
|
||||
if trimmed != "" {
|
||||
message = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
return
|
||||
default:
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -2098,8 +2042,8 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
var event ChatWatchEvent
|
||||
if err := wsjson.Read(streamCtx, conn, &event); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -2110,23 +2054,10 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var event ChatWatchEvent
|
||||
if err := json.Unmarshal(envelope.Data, &event); err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
return
|
||||
default:
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -2478,21 +2409,6 @@ func (c *ExperimentalClient) GetChatsByWorkspace(ctx context.Context, workspaceI
|
||||
return result, json.NewDecoder(res.Body).Decode(&result)
|
||||
}
|
||||
|
||||
func formatChatStreamResponseError(response Response) string {
|
||||
message := strings.TrimSpace(response.Message)
|
||||
detail := strings.TrimSpace(response.Detail)
|
||||
switch {
|
||||
case message == "" && detail == "":
|
||||
return "chat stream returned an error"
|
||||
case message == "":
|
||||
return detail
|
||||
case detail == "":
|
||||
return message
|
||||
default:
|
||||
return fmt.Sprintf("%s: %s", message, detail)
|
||||
}
|
||||
}
|
||||
|
||||
// PRInsightsResponse is the response from the PR insights endpoint.
|
||||
type PRInsightsResponse struct {
|
||||
Summary PRInsightsSummary `json:"summary"`
|
||||
|
||||
+2
-2
@@ -145,7 +145,7 @@ export const watchWorkspace = (
|
||||
export const watchChat = (
|
||||
chatId: string,
|
||||
afterMessageId?: number,
|
||||
): OneWayWebSocketApi<TypesGen.ServerSentEvent> => {
|
||||
): OneWayWebSocketApi<TypesGen.ChatStreamEvent[]> => {
|
||||
const params = new URLSearchParams();
|
||||
if (afterMessageId !== undefined && afterMessageId > 0) {
|
||||
params.set("after_id", afterMessageId.toString());
|
||||
@@ -161,7 +161,7 @@ export const watchChat = (
|
||||
});
|
||||
};
|
||||
|
||||
export const watchChats = (): OneWayWebSocket<TypesGen.ServerSentEvent> => {
|
||||
export const watchChats = (): OneWayWebSocket<TypesGen.ChatWatchEvent> => {
|
||||
const searchParams: Record<string, string> = {};
|
||||
const token = API.getSessionToken();
|
||||
if (token) {
|
||||
|
||||
@@ -198,14 +198,6 @@ const buildQueries = (
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Wrap a chat stream event payload in the JSON string format that
|
||||
* OneWayWebSocket expects when receiving a WebSocket message event.
|
||||
* The result is a `ServerSentEvent` of type `"data"` serialised to JSON.
|
||||
*/
|
||||
const wrapSSE = (payload: unknown): string =>
|
||||
JSON.stringify({ type: "data", data: payload });
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Meta
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -856,17 +848,20 @@ export const StreamedSubagentTitle: Story = {
|
||||
"/chats/": [
|
||||
{
|
||||
event: "message",
|
||||
data: wrapSSE({
|
||||
type: "message_part",
|
||||
message_part: {
|
||||
part: {
|
||||
type: "tool-call",
|
||||
tool_call_id: "tool-subagent-stream-1",
|
||||
tool_name: "spawn_agent",
|
||||
args_delta: '{"title":"Streamed Child"',
|
||||
data: JSON.stringify([
|
||||
{
|
||||
type: "message_part",
|
||||
chat_id: CHAT_ID,
|
||||
message_part: {
|
||||
part: {
|
||||
type: "tool-call",
|
||||
tool_call_id: "tool-subagent-stream-1",
|
||||
tool_name: "spawn_agent",
|
||||
args_delta: '{"title":"Streamed Child"',
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
] satisfies TypesGen.ChatStreamEvent[]),
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -1150,15 +1145,18 @@ export const StreamedReasoning: Story = {
|
||||
"/chats/": [
|
||||
{
|
||||
event: "message",
|
||||
data: wrapSSE({
|
||||
type: "message_part",
|
||||
message_part: {
|
||||
part: {
|
||||
type: "reasoning",
|
||||
text: "Streaming reasoning body",
|
||||
data: JSON.stringify([
|
||||
{
|
||||
type: "message_part",
|
||||
chat_id: CHAT_ID,
|
||||
message_part: {
|
||||
part: {
|
||||
type: "reasoning",
|
||||
text: "Streaming reasoning body",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
] satisfies TypesGen.ChatStreamEvent[]),
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -1230,18 +1228,20 @@ export const WithWaitAgentComputerUseVNC: Story = {
|
||||
"/chats/": [
|
||||
{
|
||||
event: "message",
|
||||
data: wrapSSE({
|
||||
type: "message_part",
|
||||
chat_id: CHAT_ID,
|
||||
message_part: {
|
||||
part: {
|
||||
type: "tool-call",
|
||||
tool_call_id: "tool-wait-desktop",
|
||||
tool_name: "wait_agent",
|
||||
args_delta: '{"chat_id":"desktop-child-1"}',
|
||||
data: JSON.stringify([
|
||||
{
|
||||
type: "message_part",
|
||||
chat_id: CHAT_ID,
|
||||
message_part: {
|
||||
part: {
|
||||
type: "tool-call",
|
||||
tool_call_id: "tool-wait-desktop",
|
||||
tool_name: "wait_agent",
|
||||
args_delta: '{"chat_id":"desktop-child-1"}',
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
] satisfies TypesGen.ChatStreamEvent[]),
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -51,7 +51,6 @@ import {
|
||||
chatDetailErrorsEqual,
|
||||
} from "./utils/usageLimitMessage";
|
||||
|
||||
// Type guard for SSE events from the chat list watch endpoint.
|
||||
// Shallow-compare two ChatDiffStatus objects by their meaningful
|
||||
// fields, ignoring refreshed_at/stale_at which change on every poll.
|
||||
function diffStatusEqual(
|
||||
@@ -75,19 +74,6 @@ function diffStatusEqual(
|
||||
);
|
||||
}
|
||||
|
||||
function isChatListSSEEvent(
|
||||
data: unknown,
|
||||
): data is { kind: string; chat: TypesGen.Chat } {
|
||||
if (typeof data !== "object" || data === null) return false;
|
||||
const obj = data as Record<string, unknown>;
|
||||
return (
|
||||
typeof obj.kind === "string" &&
|
||||
typeof obj.chat === "object" &&
|
||||
obj.chat !== null &&
|
||||
"id" in obj.chat
|
||||
);
|
||||
}
|
||||
|
||||
export type { AgentsOutletContext } from "./AgentsPageView";
|
||||
|
||||
const AgentsPage: FC = () => {
|
||||
@@ -495,14 +481,7 @@ const AgentsPage: FC = () => {
|
||||
console.warn("Failed to parse chat watch event:", event.parseError);
|
||||
return;
|
||||
}
|
||||
const sse = event.parsedMessage;
|
||||
if (sse?.type !== "data" || !sse.data) {
|
||||
return;
|
||||
}
|
||||
if (!isChatListSSEEvent(sse.data)) {
|
||||
return;
|
||||
}
|
||||
const chatEvent = sse.data;
|
||||
const chatEvent = event.parsedMessage;
|
||||
const updatedChat = chatEvent.chat;
|
||||
// Read the previous status from the infinite chat list
|
||||
// cache before we write the update below. The per-chat
|
||||
|
||||
@@ -55,7 +55,7 @@ vi.mock("#/api/api", () => ({
|
||||
}));
|
||||
|
||||
type MessageListener = (
|
||||
payload: OneWayMessageEvent<TypesGen.ServerSentEvent>,
|
||||
payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]>,
|
||||
) => void;
|
||||
type ErrorListener = (payload: Event) => void;
|
||||
type OpenListener = (payload: Event) => void;
|
||||
@@ -67,6 +67,7 @@ type MockSocketHelpers = {
|
||||
emitOpen: () => void;
|
||||
emitData: (event: TypesGen.ChatStreamEvent) => void;
|
||||
emitDataBatch: (events: readonly TypesGen.ChatStreamEvent[]) => void;
|
||||
emitParseError: () => void;
|
||||
emitError: () => void;
|
||||
emitClose: () => void;
|
||||
};
|
||||
@@ -143,26 +144,30 @@ const createMockSocket = (): MockSocket => {
|
||||
removeEventListener,
|
||||
close: vi.fn(),
|
||||
emitData: (event) => {
|
||||
const payload: OneWayMessageEvent<TypesGen.ServerSentEvent> = {
|
||||
const payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]> = {
|
||||
sourceEvent: {} as MessageEvent<string>,
|
||||
parseError: undefined,
|
||||
parsedMessage: {
|
||||
type: "data",
|
||||
data: event,
|
||||
},
|
||||
parsedMessage: [event],
|
||||
};
|
||||
for (const listener of messageListeners) {
|
||||
listener(payload);
|
||||
}
|
||||
},
|
||||
emitDataBatch: (events) => {
|
||||
const payload: OneWayMessageEvent<TypesGen.ServerSentEvent> = {
|
||||
const payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]> = {
|
||||
sourceEvent: {} as MessageEvent<string>,
|
||||
parseError: undefined,
|
||||
parsedMessage: {
|
||||
type: "data",
|
||||
data: events,
|
||||
},
|
||||
parsedMessage: events as TypesGen.ChatStreamEvent[],
|
||||
};
|
||||
for (const listener of messageListeners) {
|
||||
listener(payload);
|
||||
}
|
||||
},
|
||||
emitParseError: () => {
|
||||
const payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]> = {
|
||||
sourceEvent: {} as MessageEvent<string>,
|
||||
parseError: new Error("bad json"),
|
||||
parsedMessage: undefined,
|
||||
};
|
||||
for (const listener of messageListeners) {
|
||||
listener(payload);
|
||||
@@ -4209,3 +4214,215 @@ describe("store/cache desync protection", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("parse errors", () => {
|
||||
it("surfaces parseError as streamError", async () => {
|
||||
immediateAnimationFrame();
|
||||
|
||||
const chatID = "chat-parse-error";
|
||||
const mockSocket = createMockSocket();
|
||||
mockWatchChatReturn(mockSocket);
|
||||
|
||||
const queryClient = createTestQueryClient();
|
||||
const wrapper = ({ children }: PropsWithChildren) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
const setChatErrorReason = vi.fn();
|
||||
const clearChatErrorReason = vi.fn();
|
||||
|
||||
const { result } = renderHook(
|
||||
() => {
|
||||
const { store } = useChatStore({
|
||||
chatID,
|
||||
chatMessages: [],
|
||||
chatRecord: makeChat(chatID),
|
||||
chatMessagesData: {
|
||||
messages: [],
|
||||
queued_messages: [],
|
||||
has_more: false,
|
||||
},
|
||||
chatQueuedMessages: [],
|
||||
setChatErrorReason,
|
||||
clearChatErrorReason,
|
||||
});
|
||||
return {
|
||||
streamError: useChatSelector(store, selectStreamError),
|
||||
chatStatus: useChatSelector(store, selectChatStatus),
|
||||
};
|
||||
},
|
||||
{ wrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, undefined);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
mockSocket.emitParseError();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.streamError).toEqual({
|
||||
kind: "generic",
|
||||
message: "Failed to parse chat stream update.",
|
||||
});
|
||||
});
|
||||
expect(result.current.chatStatus).not.toBe("error");
|
||||
});
|
||||
|
||||
it("does not corrupt in-progress stream state", async () => {
|
||||
immediateAnimationFrame();
|
||||
|
||||
const chatID = "chat-parse-no-corrupt";
|
||||
const existingMessage = makeMessage(chatID, 1, "user", "hello");
|
||||
const mockSocket = createMockSocket();
|
||||
mockWatchChatReturn(mockSocket);
|
||||
|
||||
const queryClient = createTestQueryClient();
|
||||
const wrapper = ({ children }: PropsWithChildren) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
const setChatErrorReason = vi.fn();
|
||||
const clearChatErrorReason = vi.fn();
|
||||
|
||||
const { result } = renderHook(
|
||||
() => {
|
||||
const { store } = useChatStore({
|
||||
chatID,
|
||||
chatMessages: [existingMessage],
|
||||
chatRecord: makeChat(chatID),
|
||||
chatMessagesData: {
|
||||
messages: [existingMessage],
|
||||
queued_messages: [],
|
||||
has_more: false,
|
||||
},
|
||||
chatQueuedMessages: [],
|
||||
setChatErrorReason,
|
||||
clearChatErrorReason,
|
||||
});
|
||||
return {
|
||||
streamState: useChatSelector(store, selectStreamState),
|
||||
streamError: useChatSelector(store, selectStreamError),
|
||||
};
|
||||
},
|
||||
{ wrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
// Build up some stream state first.
|
||||
act(() => {
|
||||
mockSocket.emitData({
|
||||
type: "message_part",
|
||||
chat_id: chatID,
|
||||
message_part: {
|
||||
role: "assistant",
|
||||
part: { type: "text", text: "partial response" },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.streamState?.blocks).toEqual([
|
||||
{ type: "response", text: "partial response" },
|
||||
]);
|
||||
});
|
||||
|
||||
// Fire a parse error and verify the existing stream blocks survive.
|
||||
act(() => {
|
||||
mockSocket.emitParseError();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.streamError).toEqual({
|
||||
kind: "generic",
|
||||
message: "Failed to parse chat stream update.",
|
||||
});
|
||||
});
|
||||
expect(result.current.streamState?.blocks).toEqual([
|
||||
{ type: "response", text: "partial response" },
|
||||
]);
|
||||
});
|
||||
|
||||
it("continues processing after parse error", async () => {
|
||||
immediateAnimationFrame();
|
||||
|
||||
const chatID = "chat-parse-recover";
|
||||
const existingMessage = makeMessage(chatID, 1, "user", "hello");
|
||||
const mockSocket = createMockSocket();
|
||||
mockWatchChatReturn(mockSocket);
|
||||
|
||||
const queryClient = createTestQueryClient();
|
||||
const wrapper = ({ children }: PropsWithChildren) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
const setChatErrorReason = vi.fn();
|
||||
const clearChatErrorReason = vi.fn();
|
||||
|
||||
const { result } = renderHook(
|
||||
() => {
|
||||
const { store } = useChatStore({
|
||||
chatID,
|
||||
chatMessages: [existingMessage],
|
||||
chatRecord: makeChat(chatID),
|
||||
chatMessagesData: {
|
||||
messages: [existingMessage],
|
||||
queued_messages: [],
|
||||
has_more: false,
|
||||
},
|
||||
chatQueuedMessages: [],
|
||||
setChatErrorReason,
|
||||
clearChatErrorReason,
|
||||
});
|
||||
return {
|
||||
streamState: useChatSelector(store, selectStreamState),
|
||||
streamError: useChatSelector(store, selectStreamError),
|
||||
};
|
||||
},
|
||||
{ wrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(watchChat).toHaveBeenCalledWith(chatID, 1);
|
||||
});
|
||||
|
||||
// Trigger a parse error first.
|
||||
act(() => {
|
||||
mockSocket.emitParseError();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.streamError).toEqual({
|
||||
kind: "generic",
|
||||
message: "Failed to parse chat stream update.",
|
||||
});
|
||||
});
|
||||
|
||||
// Send a valid message_part after the parse error.
|
||||
act(() => {
|
||||
mockSocket.emitData({
|
||||
type: "message_part",
|
||||
chat_id: chatID,
|
||||
message_part: {
|
||||
role: "assistant",
|
||||
part: { type: "text", text: "recovered" },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
// The stream should process the new part normally.
|
||||
await waitFor(() => {
|
||||
expect(result.current.streamState?.blocks).toEqual([
|
||||
{ type: "response", text: "recovered" },
|
||||
]);
|
||||
});
|
||||
|
||||
// streamError is sticky and is not cleared by valid messages.
|
||||
expect(result.current.streamError).toEqual({
|
||||
kind: "generic",
|
||||
message: "Failed to parse chat stream update.",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,7 +6,6 @@ import type * as TypesGen from "#/api/typesGenerated";
|
||||
import type { OneWayMessageEvent } from "#/utils/OneWayWebSocket";
|
||||
import { createReconnectingWebSocket } from "#/utils/reconnectingWebSocket";
|
||||
import type { ChatDetailError } from "../../utils/usageLimitMessage";
|
||||
import { asNumber, asString } from "../ChatElements/runtimeTypeUtils";
|
||||
import {
|
||||
type ChatStore,
|
||||
type ChatStoreState,
|
||||
@@ -17,50 +16,24 @@ import {
|
||||
} from "./chatStore";
|
||||
import type { RetryState } from "./types";
|
||||
|
||||
const isChatStreamEvent = (data: unknown): data is TypesGen.ChatStreamEvent =>
|
||||
typeof data === "object" &&
|
||||
data !== null &&
|
||||
"type" in data &&
|
||||
typeof (data as Record<string, unknown>).type === "string";
|
||||
|
||||
const isChatStreamEventArray = (
|
||||
data: unknown,
|
||||
): data is TypesGen.ChatStreamEvent[] =>
|
||||
Array.isArray(data) && data.every(isChatStreamEvent);
|
||||
|
||||
const toChatStreamEvents = (data: unknown): TypesGen.ChatStreamEvent[] => {
|
||||
if (isChatStreamEvent(data)) {
|
||||
return [data];
|
||||
}
|
||||
if (isChatStreamEventArray(data)) {
|
||||
return data;
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const normalizeChatDetailError = (
|
||||
error: TypesGen.ChatStreamError | Record<string, unknown> | undefined,
|
||||
error: TypesGen.ChatStreamError | undefined,
|
||||
): ChatDetailError => ({
|
||||
message: asString(error?.message).trim() || "Chat processing failed.",
|
||||
kind: asString(error?.kind).trim() || "generic",
|
||||
provider: asString(error?.provider).trim() || undefined,
|
||||
retryable:
|
||||
typeof error?.retryable === "boolean" ? error.retryable : undefined,
|
||||
statusCode: asNumber(error?.status_code),
|
||||
message: error?.message.trim() || "Chat processing failed.",
|
||||
kind: error?.kind?.trim() || "generic",
|
||||
provider: error?.provider?.trim() || undefined,
|
||||
retryable: error?.retryable,
|
||||
statusCode: error?.status_code,
|
||||
});
|
||||
|
||||
const normalizeRetryState = (retry: TypesGen.ChatStreamRetry): RetryState => {
|
||||
const delayMs = asNumber(retry.delay_ms);
|
||||
const retryingAt = asString(retry.retrying_at).trim() || undefined;
|
||||
return {
|
||||
attempt: Math.max(1, asNumber(retry.attempt) ?? 1),
|
||||
error: asString(retry.error).trim() || "Retrying request shortly.",
|
||||
kind: asString(retry.kind).trim() || "generic",
|
||||
provider: asString(retry.provider).trim() || undefined,
|
||||
...(delayMs !== undefined ? { delayMs } : {}),
|
||||
...(retryingAt ? { retryingAt } : {}),
|
||||
};
|
||||
};
|
||||
const normalizeRetryState = (retry: TypesGen.ChatStreamRetry): RetryState => ({
|
||||
attempt: Math.max(1, retry.attempt),
|
||||
error: retry.error.trim() || "Retrying request shortly.",
|
||||
kind: retry.kind?.trim() || "generic",
|
||||
provider: retry.provider?.trim() || undefined,
|
||||
delayMs: retry.delay_ms,
|
||||
retryingAt: retry.retrying_at.trim() || undefined,
|
||||
});
|
||||
|
||||
const shouldSurfaceReconnectState = (state: ChatStoreState): boolean =>
|
||||
state.streamError === null &&
|
||||
@@ -419,7 +392,7 @@ export const useChatStore = (
|
||||
};
|
||||
|
||||
const handleMessage = (
|
||||
payload: OneWayMessageEvent<TypesGen.ServerSentEvent>,
|
||||
payload: OneWayMessageEvent<TypesGen.ChatStreamEvent[]>,
|
||||
) => {
|
||||
if (disposed) {
|
||||
return;
|
||||
@@ -431,11 +404,8 @@ export const useChatStore = (
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (payload.parsedMessage.type !== "data") {
|
||||
return;
|
||||
}
|
||||
|
||||
const streamEvents = toChatStreamEvents(payload.parsedMessage.data);
|
||||
const streamEvents = payload.parsedMessage;
|
||||
if (streamEvents.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user