diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index 23682f3a13..3aa10f315a 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -62,7 +62,7 @@ type Server struct { workerID uuid.UUID logger slog.Logger - remotePartsProvider RemotePartsProvider + subscribeFn SubscribeFn agentConnFn AgentConnFunc createWorkspaceFn chattool.CreateWorkspaceFn @@ -93,24 +93,41 @@ type cachedInstruction struct { // AgentConnFunc provides access to workspace agent connections. type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) -// ReplicaAddressResolver maps a replica ID to its relay address. -type ReplicaAddressResolver func(context.Context, uuid.UUID) (string, bool) - -// RemotePartsProvider returns a snapshot and live stream of message_part -// events from the replica that is running the chat. Called when the chat -// is actively running on a different replica. Nil in AGPL single-replica -// deployments. -type RemotePartsProvider func( +// SubscribeFn replaces the default local-only subscription with a +// multi-replica-aware implementation that merges pubsub notifications, +// remote relay streams, and local parts into a single event channel. +// When set, Subscribe delegates the event-merge goroutine to this +// function instead of using simple local forwarding. +// +// Parameters: +// - ctx: subscription lifetime context (canceled on unsubscribe). +// - params: all state needed to build the merged stream. +// +// Returns the merged event channel and a cleanup function. +// Set by enterprise for HA deployments. Nil in AGPL single-replica. +type SubscribeFn func( ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, -) ( - snapshot []codersdk.ChatStreamEvent, - parts <-chan codersdk.ChatStreamEvent, - cancel func(), - err error, -) + params SubscribeFnParams, +) (<-chan codersdk.ChatStreamEvent, func()) + +// StatusNotification informs the enterprise relay manager of chat +// status changes so it can open or close relay connections. +type StatusNotification struct { + Status database.ChatStatus + WorkerID uuid.UUID +} + +// SubscribeFnParams carries the state that the enterprise +// SubscribeFn implementation needs from the OSS Subscribe preamble. +type SubscribeFnParams struct { + ChatID uuid.UUID + Chat database.Chat + WorkerID uuid.UUID + StatusNotifications <-chan StatusNotification + RequestHeader http.Header + DB database.Store + Logger slog.Logger +} type chatStreamState struct { buffer []codersdk.ChatStreamEvent @@ -129,6 +146,12 @@ var ( ErrEditedMessageNotFound = xerrors.New("edited message not found") // ErrEditedMessageNotUser indicates a non-user message edit attempt. ErrEditedMessageNotUser = xerrors.New("only user messages can be edited") + + // errChatTakenByOtherWorker is a sentinel used inside the + // processChat cleanup transaction to signal that another + // worker acquired the chat, so all post-TX side effects + // (status publish, pubsub, web push) must be skipped. + errChatTakenByOtherWorker = xerrors.New("chat acquired by another worker") ) // CreateOptions controls chat creation in the shared chat mutation path. @@ -719,14 +742,31 @@ func setChatPendingWithStore( } func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) { - updatedChat, err := p.db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chatID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: sql.NullString{}, - }) + var updatedChat database.Chat + err := p.db.InTx(func(tx database.Store) error { + locked, lockErr := tx.GetChatByIDForUpdate(ctx, chatID) + if lockErr != nil { + return xerrors.Errorf("lock chat for waiting: %w", lockErr) + } + // If the chat has already transitioned to pending (e.g. + // SendMessage with interrupt behavior), don't overwrite + // it — the pending status takes priority so the new + // message gets processed. + if locked.Status == database.ChatStatusPending { + updatedChat = locked + return nil + } + var updateErr error + updatedChat, updateErr = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chatID, + Status: database.ChatStatusWaiting, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + return updateErr + }, nil) if err != nil { return database.Chat{}, err } @@ -807,7 +847,7 @@ type Config struct { Logger slog.Logger Database database.Store ReplicaID uuid.UUID - RemotePartsProvider RemotePartsProvider + SubscribeFn SubscribeFn PendingChatAcquireInterval time.Duration InFlightChatStaleAfter time.Duration AgentConn AgentConnFunc @@ -844,7 +884,7 @@ func New(cfg Config) *Server { db: cfg.Database, workerID: workerID, logger: cfg.Logger.Named("chat-processor"), - remotePartsProvider: cfg.RemotePartsProvider, + subscribeFn: cfg.SubscribeFn, agentConnFn: cfg.AgentConn, createWorkspaceFn: cfg.CreateWorkspace, pubsub: cfg.Pubsub, @@ -954,10 +994,12 @@ func (p *Server) subscribeToStream(chatID uuid.UUID) ( p.streamMu.Lock() state, ok := p.chatStreams[chatID] if ok { - if subscriber, exists := state.subscribers[id]; exists { - delete(state.subscribers, id) - close(subscriber) - } + // Remove the subscriber but do not close the channel. + // publishToStream copies subscriber references under + // streamMu then sends outside the lock; closing here + // races with that send and can panic. The channel + // becomes unreachable once removed and will be GC'd. + delete(state.subscribers, id) p.cleanupStreamIfIdleLocked(chatID, state) } p.streamMu.Unlock() @@ -1005,7 +1047,7 @@ func (p *Server) Subscribe( // Subscribe to local stream for message_parts (ephemeral). localSnapshot, localParts, localCancel := p.subscribeToStream(chatID) - // Build initial snapshot synchronously + // Build initial snapshot synchronously. initialSnapshot := make([]codersdk.ChatStreamEvent, 0) // Add local message_parts to snapshot for _, event := range localSnapshot { @@ -1033,7 +1075,7 @@ func (p *Server) Subscribe( } } - // Load initial queue + // Load initial queue. queued, err := p.db.GetChatQueuedMessages(ctx, chatID) if err == nil && len(queued) > 0 { initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{ @@ -1043,24 +1085,8 @@ func (p *Server) Subscribe( }) } - // Get initial chat state to determine if we need a relay + // Get initial chat state to determine if we need a relay. chat, err := p.db.GetChatByID(ctx, chatID) - var relayCancel func() - var relayParts <-chan codersdk.ChatStreamEvent - if err == nil && chat.Status == database.ChatStatusRunning && chat.WorkerID.Valid && chat.WorkerID.UUID != p.workerID && p.remotePartsProvider != nil { - // Open relay for initial snapshot - snapshot, parts, cancel, err := p.remotePartsProvider(ctx, chatID, chat.WorkerID.UUID, requestHeader) - if err == nil { - relayCancel = cancel - relayParts = parts - // Add relay message_parts to snapshot - for _, event := range snapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - initialSnapshot = append(initialSnapshot, event) - } - } - } - } // Include the current chat status in the snapshot so the // frontend can gate message_part processing correctly from @@ -1079,119 +1105,38 @@ func (p *Server) Subscribe( initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...) } - // Track the last message ID we've seen for DB queries - var lastMessageID int64 + // Track the last message ID we've seen for DB queries. + // Initialize from afterMessageID so that when the caller passes + // afterMessageID > 0 but no new messages exist yet, the first + // pubsub catch-up doesn't re-fetch already-seen messages. + lastMessageID := afterMessageID if len(messages) > 0 { lastMessageID = messages[len(messages)-1].ID } - // Merge all event sources + // Merge all event sources. mergedCtx, mergedCancel := context.WithCancel(ctx) mergedEvents := make(chan codersdk.ChatStreamEvent, 128) + var allCancels []func() allCancels = append(allCancels, localCancel) - if relayCancel != nil { - allCancels = append(allCancels, relayCancel) - } - // Channel for async relay establishment. - type relayResult struct { - parts <-chan codersdk.ChatStreamEvent - cancel func() - } - relayReadyCh := make(chan relayResult, 1) - - // Reconnect timer state. - var reconnectTimer *time.Timer - var reconnectCh <-chan time.Time - - // Helper to close relay and stop any pending reconnect timer. - closeRelay := func() { - if relayCancel != nil { - relayCancel() - relayCancel = nil - } - relayParts = nil - if reconnectTimer != nil { - reconnectTimer.Stop() - reconnectTimer = nil - reconnectCh = nil - } - } - - // openRelayAsync dials the remote replica in a background - // goroutine and delivers the result on relayReadyCh so the - // main select loop is never blocked by network I/O. - openRelayAsync := func(workerID uuid.UUID) { - if p.remotePartsProvider == nil { - return - } - closeRelay() - go func() { - snapshot, parts, cancel, err := p.remotePartsProvider(mergedCtx, chatID, workerID, requestHeader) - if err != nil { - p.logger.Warn(mergedCtx, "failed to open relay for message parts", - slog.F("chat_id", chatID), - slog.F("worker_id", workerID), - slog.Error(err), - ) - return - } - // Wrap the relay channel so snapshot parts are - // delivered through the same channel as live parts. - wrappedParts := make(chan codersdk.ChatStreamEvent, 128) - go func() { - defer close(wrappedParts) - for _, event := range snapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - select { - case wrappedParts <- event: - case <-mergedCtx.Done(): - cancel() - return - } - } - } - for event := range parts { - select { - case wrappedParts <- event: - case <-mergedCtx.Done(): - return - } - } - }() - select { - case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel}: - case <-mergedCtx.Done(): - cancel() - } - }() - } - - // scheduleRelayReconnect arms a short timer so the select - // loop can re-check chat status and reopen the relay without - // spinning in a tight loop. - scheduleRelayReconnect := func() { - if p.remotePartsProvider == nil { - return - } - if reconnectTimer != nil { - reconnectTimer.Stop() - } - reconnectTimer = time.NewTimer(500 * time.Millisecond) - reconnectCh = reconnectTimer.C - } - - //nolint:nestif + // Subscribe to pubsub for durable events (status, messages, + // queue updates, errors). When pubsub is nil (e.g. in-memory + // single-instance) we skip this and deliver all local events. + var notifications <-chan coderdpubsub.ChatStreamNotifyMessage + var errCh <-chan error if p.pubsub != nil { - notifications := make(chan coderdpubsub.ChatStreamNotifyMessage, 10) - errCh := make(chan error, 1) + notifyCh := make(chan coderdpubsub.ChatStreamNotifyMessage, 10) + errNotifyCh := make(chan error, 1) + notifications = notifyCh + errCh = errNotifyCh - listener := func(_ context.Context, message []byte, err error) { - if err != nil { + listener := func(_ context.Context, message []byte, listenErr error) { + if listenErr != nil { select { case <-mergedCtx.Done(): - case errCh <- err: + case errNotifyCh <- listenErr: } return } @@ -1199,187 +1144,214 @@ func (p *Server) Subscribe( if unmarshalErr := json.Unmarshal(message, ¬ify); unmarshalErr != nil { select { case <-mergedCtx.Done(): - case errCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr): + case errNotifyCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr): } return } select { case <-mergedCtx.Done(): - case notifications <- notify: + case notifyCh <- notify: } } - // Subscribe to pubsub for durable events - if pubsubCancel, err := p.pubsub.SubscribeWithErr( + if pubsubCancel, pubsubErr := p.pubsub.SubscribeWithErr( coderdpubsub.ChatStreamNotifyChannel(chatID), listener, - ); err == nil { + ); pubsubErr == nil { allCancels = append(allCancels, pubsubCancel) } else { - p.logger.Warn(mergedCtx, "failed to subscribe to chat stream notifications", + p.logger.Warn(ctx, "failed to subscribe to chat stream notifications", slog.F("chat_id", chatID), - slog.Error(err), + slog.Error(pubsubErr), ) } + } - // Handle pubsub notifications in a goroutine - go func() { - defer close(mergedEvents) - defer closeRelay() + // When an enterprise SubscribeFn is provided and the chat + // lookup succeeded, call it to get relay events (message_parts + // from remote replicas). OSS now owns pubsub subscription, + // message catch-up, queue updates, and status forwarding; + // enterprise only manages relay dialing. + var relayEvents <-chan codersdk.ChatStreamEvent + var relayCleanup func() + var statusNotifications chan StatusNotification + if p.subscribeFn != nil && err == nil { + statusNotifications = make(chan StatusNotification, 10) + var relayEvCh <-chan codersdk.ChatStreamEvent + relayEvCh, relayCleanup = p.subscribeFn(mergedCtx, SubscribeFnParams{ + ChatID: chatID, + Chat: chat, + WorkerID: p.workerID, + StatusNotifications: statusNotifications, + RequestHeader: requestHeader, + DB: p.db, + Logger: p.logger, + }) + relayEvents = relayEvCh + } - for { - relayPartsCh := relayParts + hasPubsub := false + if p.pubsub != nil { + // hasPubsub is only true when we actually subscribed + // successfully above (allCancels will contain the pubsub + // cancel func in that case). + hasPubsub = len(allCancels) > 1 + } + + //nolint:nestif + go func() { + defer close(mergedEvents) + if statusNotifications != nil { + defer close(statusNotifications) + } + for { + select { + case <-mergedCtx.Done(): + return + case psErr := <-errCh: + p.logger.Error(mergedCtx, "chat stream pubsub error", + slog.F("chat_id", chatID), + slog.Error(psErr), + ) select { + case mergedEvents <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeError, + ChatID: chatID, + Error: &codersdk.ChatStreamError{ + Message: psErr.Error(), + }, + }: case <-mergedCtx.Done(): - return - case err := <-errCh: - p.logger.Error(mergedCtx, "chat stream pubsub error", - slog.F("chat_id", chatID), - slog.Error(err), - ) - mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatStreamError{ - Message: err.Error(), - }, - } - return - case result := <-relayReadyCh: - // An async relay dial completed; swap in the - // new relay channel. - closeRelay() - relayParts = result.parts - relayCancel = result.cancel - case <-reconnectCh: - reconnectCh = nil - // Re-check whether the chat is still running - // on a remote worker before reconnecting. - currentChat, chatErr := p.db.GetChatByID(mergedCtx, chatID) - if chatErr == nil && currentChat.Status == database.ChatStatusRunning && - currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != p.workerID { - openRelayAsync(currentChat.WorkerID.UUID) - } - case notify := <-notifications: - // Handle different notification types - if notify.AfterMessageID > 0 { - // Read only new messages from DB. - messages, err := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: lastMessageID, - }) - if err == nil { - for _, msg := range messages { - sdkMsg := db2sdk.ChatMessage(msg) - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &sdkMsg, - }: - } - lastMessageID = msg.ID - } - } - } - if notify.Status != "" { - status := database.ChatStatus(notify.Status) - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeStatus, - ChatID: chatID, - Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)}, - }: - } - // Manage relay lifecycle based on status. - if status == database.ChatStatusRunning && notify.WorkerID != "" { - workerID, err := uuid.Parse(notify.WorkerID) - if err == nil && workerID != p.workerID { - openRelayAsync(workerID) - } else if workerID == p.workerID { - closeRelay() - } - } else { - closeRelay() - } - } - if notify.Error != "" { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatStreamError{ - Message: notify.Error, - }, - }: - } - } - if notify.QueueUpdate { - queued, err := p.db.GetChatQueuedMessages(mergedCtx, chatID) - if err == nil { + } + return + case notify := <-notifications: + if notify.AfterMessageID > 0 { + newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: lastMessageID, + }) + if msgErr != nil { + p.logger.Warn(mergedCtx, "failed to get chat messages after pubsub notification", + slog.F("chat_id", chatID), + slog.Error(msgErr), + ) + } else { + for _, msg := range newMessages { + sdkMsg := db2sdk.ChatMessage(msg) select { case <-mergedCtx.Done(): return case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeQueueUpdate, - ChatID: chatID, - QueuedMessages: db2sdk.ChatQueuedMessages(queued), + Type: codersdk.ChatStreamEventTypeMessage, + ChatID: chatID, + Message: &sdkMsg, }: } - } - } - case event, ok := <-localParts: - if !ok { - // Local parts channel closed, but continue with pubsub - continue - } - // Only forward message_part events from local (durable events come via pubsub) - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- event: - } - } - case event, ok := <-relayPartsCh: - if !ok { - relayParts = nil - // Schedule reconnection instead of giving up. - scheduleRelayReconnect() - continue - } - // Only forward message_part events from relay (durable events come via pubsub) - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- event: + lastMessageID = msg.ID } } } - } - }() - } else { - // No pubsub, just merge local parts. - // localSnapshot was already included in initialSnapshot, - // so only forward new events here. - go func() { - defer close(mergedEvents) - for event := range localParts { + if notify.Status != "" { + status := database.ChatStatus(notify.Status) + select { + case <-mergedCtx.Done(): + return + case mergedEvents <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeStatus, + ChatID: chatID, + Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)}, + }: + } + // Notify enterprise relay manager if present. + if statusNotifications != nil { + workerID := uuid.Nil + if notify.WorkerID != "" { + if parsed, parseErr := uuid.Parse(notify.WorkerID); parseErr == nil { + workerID = parsed + } + } + select { + case statusNotifications <- StatusNotification{Status: status, WorkerID: workerID}: + case <-mergedCtx.Done(): + return + } + } + } + if notify.Error != "" { + select { + case <-mergedCtx.Done(): + return + case mergedEvents <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeError, + ChatID: chatID, + Error: &codersdk.ChatStreamError{ + Message: notify.Error, + }, + }: + } + } + if notify.QueueUpdate { + queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID) + if queueErr != nil { + p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification", + slog.F("chat_id", chatID), + slog.Error(queueErr), + ) + } else { + select { + case <-mergedCtx.Done(): + return + case mergedEvents <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeQueueUpdate, + ChatID: chatID, + QueuedMessages: db2sdk.ChatQueuedMessages(queuedMsgs), + }: + } + } + } + case event, ok := <-localParts: + if !ok { + localParts = nil + // Local parts channel closed. If pubsub is + // active we continue with pubsub-driven events. + // Otherwise terminate. + if !hasPubsub { + return + } + continue + } + if hasPubsub { + // Only forward message_part events from local + // (durable events come via pubsub). + if event.Type == codersdk.ChatStreamEventTypeMessagePart { + select { + case <-mergedCtx.Done(): + return + case mergedEvents <- event: + } + } + } else { + // No pubsub: forward all event types. + select { + case <-mergedCtx.Done(): + return + case mergedEvents <- event: + } + } + case event, ok := <-relayEvents: + if !ok { + relayEvents = nil + continue + } select { case <-mergedCtx.Done(): return case mergedEvents <- event: } } - }() - } + } + }() + cancel := func() { mergedCancel() for _, cancelFn := range allCancels { @@ -1387,11 +1359,10 @@ func (p *Server) Subscribe( cancelFn() } } - if reconnectTimer != nil { - reconnectTimer.Stop() + if relayCleanup != nil { + relayCleanup() } } - return initialSnapshot, mergedEvents, cancel, true } @@ -1733,6 +1704,15 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { return xerrors.Errorf("lock chat for release: %w", lockErr) } + // If another worker has already acquired this chat, + // bail out — we must not overwrite their running + // status or publish spurious events. + if latestChat.Status == database.ChatStatusRunning && + latestChat.WorkerID.Valid && + latestChat.WorkerID.UUID != p.workerID { + return errChatTakenByOtherWorker + } + // If someone else already set the chat to pending (e.g. // the promote endpoint), don't overwrite it — just clear // the worker and let the processor pick it back up. @@ -1787,6 +1767,12 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { }) return updateErr }, nil) + if errors.Is(err, errChatTakenByOtherWorker) { + // Another worker owns this chat now — skip all + // post-TX side effects (status publish, pubsub, + // web push) to avoid overwriting their state. + return + } if err != nil { logger.Error(cleanupCtx, "failed to release chat", slog.Error(err)) } diff --git a/coderd/chatd/chatd_test.go b/coderd/chatd/chatd_test.go index b495aa7b8c..5b4c76af0e 100644 --- a/coderd/chatd/chatd_test.go +++ b/coderd/chatd/chatd_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strings" "sync" "sync/atomic" @@ -28,7 +27,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" - coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisioner/echo" @@ -1133,30 +1131,6 @@ func newTestServer( return server } -func newTestServerWithRelay( - t *testing.T, - db database.Store, - ps dbpubsub.Pubsub, - replicaID uuid.UUID, - provider chatd.RemotePartsProvider, -) *chatd.Server { - t.Helper() - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: replicaID, - Pubsub: ps, - RemotePartsProvider: provider, - PendingChatAcquireInterval: testutil.WaitSuperLong, - }) - t.Cleanup(func() { - require.NoError(t, server.Close()) - }) - return server -} - func seedChatDependencies( ctx context.Context, t *testing.T, @@ -1213,293 +1187,6 @@ func setOpenAIProviderBaseURL( require.NoError(t, err) } -func TestSubscribeRelayReconnectsOnDrop(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - - provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - ch := make(chan codersdk.ChatStreamEvent, 10) - if call == 1 { - // First relay: send a part then close to simulate a drop. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first-relay"}, - }, - } - close(ch) - } else { - // Second relay: send a different part, keep open. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "second-relay"}, - }, - } - // Don't close — keep alive so the subscriber stays connected. - } - return nil, ch, func() {}, nil - } - - subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider) - - ctx := testutil.Context(t, testutil.WaitLong) - user, model := seedChatDependencies(ctx, t, db) - - // Create a chat and mark it as running on a remote worker. - chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{ - OwnerID: user.ID, - Title: "relay-reconnect", - ModelConfigID: model.ID, - InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, - }) - require.NoError(t, err) - - chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, - }) - require.NoError(t, err) - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Should get the first relay part. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "first-relay" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // After the first relay closes, a reconnection should happen and - // deliver the second relay part. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "second-relay" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - require.GreaterOrEqual(t, int(callCount.Load()), 2) -} - -func TestSubscribeRelayAsyncDoesNotBlock(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - dialStarted := make(chan struct{}) - dialContinue := make(chan struct{}) - - provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - // Signal that the dial has started, then block until released. - select { - case <-dialStarted: - default: - close(dialStarted) - } - select { - case <-dialContinue: - case <-ctx.Done(): - return nil, nil, nil, ctx.Err() - } - ch := make(chan codersdk.ChatStreamEvent, 10) - return nil, ch, func() {}, nil - } - - subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider) - - ctx := testutil.Context(t, testutil.WaitLong) - user, model := seedChatDependencies(ctx, t, db) - - // Create a chat in pending status. - chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{ - OwnerID: user.ID, - Title: "relay-async-nonblock", - ModelConfigID: model.ID, - InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, - }) - require.NoError(t, err) - - // Subscribe before the chat is marked running so the relay opens - // via pubsub notification (openRelayAsync path). - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Now mark the chat as running on a remote worker. This publishes - // a status notification which triggers openRelayAsync on the - // subscriber. - notify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: workerID.String(), - } - payload, err := json.Marshal(notify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) - require.NoError(t, err) - - // Wait for the relay dial to actually start (blocking in the - // provider). - select { - case <-dialStarted: - case <-ctx.Done(): - t.Fatal("timed out waiting for relay dial to start") - } - - // While the relay is still dialing (provider is blocked), publish - // another status change. If openRelayAsync blocked the select loop - // this event would never arrive. - statusNotify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusWaiting), - } - statusPayload, err := json.Marshal(statusNotify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), statusPayload) - require.NoError(t, err) - - // The waiting status event should arrive promptly despite the - // relay still dialing. - require.Eventually(t, func() bool { - select { - case event := <-events: - return event.Type == codersdk.ChatStreamEventTypeStatus && - event.Status != nil && - event.Status.Status == codersdk.ChatStatusWaiting - default: - return false - } - }, testutil.WaitShort, testutil.IntervalFast) - - // Unblock the relay dial so the test can clean up. - close(dialContinue) -} - -func TestSubscribeRelaySnapshotDelivered(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - // Return a non-empty snapshot with two parts. - snapshot := []codersdk.ChatStreamEvent{ - { - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-one"}, - }, - }, - { - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-two"}, - }, - }, - } - ch := make(chan codersdk.ChatStreamEvent, 10) - // Also send a live part after the snapshot. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "live-part"}, - }, - } - return snapshot, ch, func() {}, nil - } - - subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider) - - ctx := testutil.Context(t, testutil.WaitLong) - user, model := seedChatDependencies(ctx, t, db) - - // Create a chat already running on a remote worker. - chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{ - OwnerID: user.ID, - Title: "relay-snapshot", - ModelConfigID: model.ID, - InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, - }) - require.NoError(t, err) - - _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, - }) - require.NoError(t, err) - - initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // The initial snapshot should contain the two relay snapshot parts. - var snapshotTexts []string - for _, event := range initialSnapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart && event.MessagePart != nil { - snapshotTexts = append(snapshotTexts, event.MessagePart.Part.Text) - } - } - require.Contains(t, snapshotTexts, "snap-one") - require.Contains(t, snapshotTexts, "snap-two") - - // The live part should arrive on the events channel. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "live-part" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) -} - func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) { t.Parallel() diff --git a/coderd/chats_test.go b/coderd/chats_test.go index 425f2a6297..911cc6133c 100644 --- a/coderd/chats_test.go +++ b/coderd/chats_test.go @@ -1694,7 +1694,7 @@ func TestStreamChat(t *testing.T) { }) require.NoError(t, err) - events, closer, err := client.StreamChat(ctx, chat.ID) + events, closer, err := client.StreamChat(ctx, chat.ID, nil) require.NoError(t, err) defer closer.Close() diff --git a/coderd/coderd.go b/coderd/coderd.go index afb99c6f19..f4a8e7c9eb 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -239,9 +239,9 @@ type Options struct { SSHConfig codersdk.SSHConfigResponse HTTPClient *http.Client - // ChatRemotePartsProvider provides cross-replica message_part streaming. + // ChatSubscribeFn provides cross-replica subscription merging. // Set by enterprise for HA deployments. Nil in AGPL single-replica. - ChatRemotePartsProvider chatd.RemotePartsProvider + ChatSubscribeFn chatd.SubscribeFn UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) StatsBatcher workspacestats.Batcher @@ -760,15 +760,15 @@ func New(options *Options) *API { api.agentProvider = stn api.chatDaemon = chatd.New(chatd.Config{ - Logger: options.Logger.Named("chats"), - Database: options.Database, - ReplicaID: api.ID, - RemotePartsProvider: options.ChatRemotePartsProvider, - ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues), - AgentConn: api.agentProvider.AgentConn, - CreateWorkspace: api.chatCreateWorkspace, - Pubsub: options.Pubsub, - WebpushDispatcher: options.WebPushDispatcher, + Logger: options.Logger.Named("chats"), + Database: options.Database, + ReplicaID: api.ID, + SubscribeFn: options.ChatSubscribeFn, + ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues), + AgentConn: api.agentProvider.AgentConn, + CreateWorkspace: api.chatCreateWorkspace, + Pubsub: options.Pubsub, + WebpushDispatcher: options.WebPushDispatcher, }) if options.DeploymentValues.Prometheus.Enable { options.PrometheusRegistry.MustRegister(stn) diff --git a/codersdk/chats.go b/codersdk/chats.go index df0b9b5d24..bbcf48d94d 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -670,15 +670,29 @@ func (c *Client) CreateChat(ctx context.Context, req CreateChatRequest) (Chat, e return chat, json.NewDecoder(res.Body).Decode(&chat) } +// StreamChatOptions are optional parameters for StreamChat. +type StreamChatOptions struct { + // AfterID limits the initial snapshot to messages created + // after the given ID. This is useful for relay connections + // that only need live message_part events and can skip the + // full message history. + AfterID *int64 +} + // StreamChat streams chat updates in real time. // // The returned channel includes initial snapshot events first, followed by // live updates. Callers must close the returned io.Closer to release the // websocket connection when done. -func (c *Client) StreamChat(ctx context.Context, chatID uuid.UUID) (<-chan ChatStreamEvent, io.Closer, error) { +func (c *Client) StreamChat(ctx context.Context, chatID uuid.UUID, opts *StreamChatOptions) (<-chan ChatStreamEvent, io.Closer, error) { + path := fmt.Sprintf("/api/experimental/chats/%s/stream", chatID) + if opts != nil && opts.AfterID != nil { + path += fmt.Sprintf("?after_id=%d", *opts.AfterID) + } + conn, err := c.Dial( ctx, - fmt.Sprintf("/api/experimental/chats/%s/stream", chatID), + path, &websocket.DialOptions{CompressionMode: websocket.CompressionDisabled}, ) if err != nil { diff --git a/enterprise/coderd/chatd/chatd.go b/enterprise/coderd/chatd/chatd.go new file mode 100644 index 0000000000..673f4969c5 --- /dev/null +++ b/enterprise/coderd/chatd/chatd.go @@ -0,0 +1,575 @@ +package chatd + +import ( + "context" + "math" + "net/http" + "net/url" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + osschatd "github.com/coder/coder/v2/coderd/chatd" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" + "github.com/coder/websocket" +) + +// RelaySourceHeader marks replica-relayed stream requests. +const RelaySourceHeader = "X-Coder-Relay-Source-Replica" + +const ( + authorizationHeader = "Authorization" + cookieHeader = "Cookie" +) + +// MultiReplicaSubscribeConfig holds the dependencies for multi-replica chat +// subscription. ReplicaIDFn is called lazily because the +// replica ID may not be known at construction time. +// +// DialerFn, when set, overrides the default WebSocket relay +// dialer. This is used in tests to inject mock relay behavior +// without requiring real HTTP servers. +type MultiReplicaSubscribeConfig struct { + ResolveReplicaAddress func(context.Context, uuid.UUID) (string, bool) + ReplicaHTTPClient *http.Client + ReplicaIDFn func() uuid.UUID + DialerFn func( + ctx context.Context, + chatID uuid.UUID, + workerID uuid.UUID, + requestHeader http.Header, + ) ( + snapshot []codersdk.ChatStreamEvent, + parts <-chan codersdk.ChatStreamEvent, + cancel func(), + err error, + ) + // Clock is used for creating timers. In production use + // quartz.NewReal(); in tests use quartz.NewMock(t) to + // control reconnect timing deterministically. + Clock quartz.Clock +} + +// dial returns the dialer function to use for relay connections. +// If DialerFn is set (e.g. in tests), it takes precedence. +// Otherwise, dialRelay is used with the real MultiReplicaSubscribeConfig dependencies. +// Returns nil when no relay capability is configured. +func (c MultiReplicaSubscribeConfig) dial() func( + ctx context.Context, + chatID uuid.UUID, + workerID uuid.UUID, + requestHeader http.Header, +) ( + []codersdk.ChatStreamEvent, + <-chan codersdk.ChatStreamEvent, + func(), + error, +) { + if c.DialerFn != nil { + return c.DialerFn + } + if c.ResolveReplicaAddress == nil { + return nil + } + return func( + ctx context.Context, + chatID uuid.UUID, + workerID uuid.UUID, + requestHeader http.Header, + ) ( + []codersdk.ChatStreamEvent, + <-chan codersdk.ChatStreamEvent, + func(), + error, + ) { + return dialRelay(ctx, chatID, workerID, requestHeader, c, c.clock()) + } +} + +// clock returns the quartz.Clock to use. Defaults to a real clock +// when not set. +func (c MultiReplicaSubscribeConfig) clock() quartz.Clock { + if c.Clock != nil { + return c.Clock + } + return quartz.NewReal() +} + +// NewMultiReplicaSubscribeFn returns a SubscribeFn that manages +// relay connections to remote replicas and returns relay +// message_part events only. OSS handles pubsub subscription, +// message catch-up, queue updates, status forwarding, and local +// parts merging. +// +//nolint:gocognit // Complexity is inherent to the multi-source merge loop. +func NewMultiReplicaSubscribeFn( + cfg MultiReplicaSubscribeConfig, +) osschatd.SubscribeFn { + return func(ctx context.Context, params osschatd.SubscribeFnParams) (<-chan codersdk.ChatStreamEvent, func()) { + chatID := params.ChatID + requestHeader := params.RequestHeader + logger := params.Logger + + var relayCancel func() + var relayParts <-chan codersdk.ChatStreamEvent + + // If the chat is currently running on a different worker + // and we have a remote parts provider, open an initial + // relay synchronously so the caller gets in-flight + // message_part events right away. + var initialRelaySnapshot []codersdk.ChatStreamEvent + if params.Chat.Status == database.ChatStatusRunning && + params.Chat.WorkerID.Valid && + params.Chat.WorkerID.UUID != params.WorkerID && + cfg.dial() != nil { + snapshot, parts, cancel, err := cfg.dial()(ctx, chatID, params.Chat.WorkerID.UUID, requestHeader) + if err == nil { + relayCancel = cancel + relayParts = parts + // Collect relay message_parts to forward at the + // start of the merge goroutine. + for _, event := range snapshot { + if event.Type == codersdk.ChatStreamEventTypeMessagePart { + initialRelaySnapshot = append(initialRelaySnapshot, event) + } + } + } else { + logger.Warn(ctx, "failed to open initial relay for chat stream", + slog.F("chat_id", chatID), + slog.Error(err), + ) + } + } + + // Merge all event sources. + mergedEvents := make(chan codersdk.ChatStreamEvent, 128) + var allCancels []func() + if relayCancel != nil { + allCancels = append(allCancels, relayCancel) + } + + // Channel for async relay establishment. + type relayResult struct { + parts <-chan codersdk.ChatStreamEvent + cancel func() + workerID uuid.UUID // the worker this dial targeted + } + relayReadyCh := make(chan relayResult, 1) + + // Per-dial context so in-flight dials can be canceled when + // a new dial is initiated or the relay is closed. + var dialCancel context.CancelFunc + + // expectedWorkerID tracks which replica we expect the next + // relay result to target. Stale results are discarded. + var expectedWorkerID uuid.UUID + + // Reconnect timer state. + var reconnectTimer *quartz.Timer + var reconnectCh <-chan time.Time + + // Helper to close relay and stop any pending reconnect + // timer. + closeRelay := func() { + // Cancel any in-flight dial goroutine first. + if dialCancel != nil { + dialCancel() + dialCancel = nil + } + // Drain any buffered relay result from a canceled + // dial. + select { + case result := <-relayReadyCh: + if result.cancel != nil { + result.cancel() + } + default: + } + expectedWorkerID = uuid.Nil + if relayCancel != nil { + relayCancel() + relayCancel = nil + } + relayParts = nil + if reconnectTimer != nil { + reconnectTimer.Stop() + reconnectTimer = nil + reconnectCh = nil + } + } + + // openRelayAsync dials the remote replica in a background + // goroutine and delivers the result on relayReadyCh so the + // main select loop is never blocked by network I/O. + openRelayAsync := func(workerID uuid.UUID) { + if cfg.dial() == nil { + return + } + closeRelay() + // Create a per-dial context so this goroutine is + // canceled if closeRelay() or openRelayAsync() is + // called again before the dial completes. + var dialCtx context.Context + dialCtx, dialCancel = context.WithCancel(ctx) + expectedWorkerID = workerID + go func() { + snapshot, parts, cancel, err := cfg.dial()(dialCtx, chatID, workerID, requestHeader) + if err != nil { + // Don't log context-canceled errors + // since they are expected when a dial is + // superseded by a newer one. + if dialCtx.Err() == nil { + logger.Warn(ctx, "failed to open relay for message parts", + slog.F("chat_id", chatID), + slog.F("worker_id", workerID), + slog.Error(err), + ) + } + // Send an empty result so the merge loop + // can schedule a reconnect attempt. + select { + case relayReadyCh <- relayResult{workerID: workerID}: + case <-dialCtx.Done(): + } + return + } // If the dial context was canceled while the + // dial was in progress, discard the result to + // avoid starting a wrappedParts goroutine for + // a stale connection. + if dialCtx.Err() != nil { + cancel() + return + } + // Wrap the relay channel so snapshot parts + // are delivered through the same channel as + // live parts. This goroutine only forwards + // events — it does not own the relay + // lifecycle. When dialCtx is canceled it + // simply returns, closing wrappedParts via + // its defer. The cancel() is called by + // whoever canceled dialCtx (closeRelay or + // the send-fallback select below). + wrappedParts := make(chan codersdk.ChatStreamEvent, 128) + go func() { + defer close(wrappedParts) + for _, event := range snapshot { + if event.Type == codersdk.ChatStreamEventTypeMessagePart { + select { + case wrappedParts <- event: + case <-dialCtx.Done(): + return + } + } + } + for { + select { + case event, ok := <-parts: + if !ok { + return + } + select { + case wrappedParts <- event: + case <-dialCtx.Done(): + return + } + case <-dialCtx.Done(): + return + } + } + }() + select { + case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel, workerID: workerID}: + case <-dialCtx.Done(): + cancel() + } + }() + } + + // scheduleRelayReconnect arms a short timer so the select + // loop can re-check chat status and reopen the relay + // without spinning in a tight loop. + scheduleRelayReconnect := func() { + if cfg.dial() == nil { + return + } + if reconnectTimer != nil { + reconnectTimer.Stop() + } + reconnectTimer = cfg.clock().NewTimer(500*time.Millisecond, "reconnect") + reconnectCh = reconnectTimer.C + } + + statusNotifications := params.StatusNotifications + go func() { + defer close(mergedEvents) + defer closeRelay() + + // Forward any initial relay snapshot parts + // collected synchronously above. + for _, event := range initialRelaySnapshot { + select { + case <-ctx.Done(): + return + case mergedEvents <- event: + } + } + + for { + relayPartsCh := relayParts + select { + case <-ctx.Done(): + return + case result := <-relayReadyCh: + // Discard stale relay results from a + // previous dial that was superseded. + if result.workerID != expectedWorkerID { + if result.cancel != nil { + result.cancel() + } + continue + } + // A nil parts channel signals the dial + // failed — schedule a retry. + if result.parts == nil { + scheduleRelayReconnect() + continue + } + // An async relay dial completed; swap + // in the new relay channel. + if relayCancel != nil { + relayCancel() + } + relayParts = result.parts + relayCancel = result.cancel + case <-reconnectCh: + reconnectCh = nil + // Re-check whether the chat is still + // running on a remote worker before + // reconnecting. + currentChat, chatErr := params.DB.GetChatByID(ctx, chatID) + if chatErr != nil { + logger.Warn(ctx, "failed to get chat for relay reconnect", + slog.F("chat_id", chatID), + slog.Error(chatErr), + ) + // Retry on transient DB errors to + // avoid permanently stalling the + // stream. + scheduleRelayReconnect() + continue + } + if currentChat.Status == database.ChatStatusRunning && + currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != params.WorkerID { + openRelayAsync(currentChat.WorkerID.UUID) + } + case sn, ok := <-statusNotifications: + if !ok { + statusNotifications = nil + continue + } + if sn.Status == database.ChatStatusRunning && sn.WorkerID != uuid.Nil && sn.WorkerID != params.WorkerID { + openRelayAsync(sn.WorkerID) + } else { + closeRelay() + } + case event, ok := <-relayPartsCh: + if !ok { + if relayCancel != nil { + relayCancel() + relayCancel = nil + } + relayParts = nil + // Schedule reconnection instead of + // giving up. + scheduleRelayReconnect() + continue + } + // Only forward message_part events from + // relay. + if event.Type == codersdk.ChatStreamEventTypeMessagePart { + select { + case <-ctx.Done(): + return + case mergedEvents <- event: + } + } + } + } + }() + + // The cancel function tears down the relay state + // indirectly: the merge goroutine owns all relay state + // (reconnectTimer, relayCancel, dialCancel, etc.) and + // cleans it up via its defer closeRelay() when ctx is + // canceled. + cancel := func() { + for _, cancelFn := range allCancels { + if cancelFn != nil { + cancelFn() + } + } + } + return mergedEvents, cancel + } +} + +// dialRelay opens a WebSocket relay connection to the replica +// identified by workerID and returns a snapshot of buffered +// message_part events plus a live channel of subsequent events. +// It passes afterID=MaxInt64 so the remote replica skips the +// full message history snapshot, since the relay only needs +// live message_part events. +func dialRelay( + ctx context.Context, + chatID uuid.UUID, + workerID uuid.UUID, + requestHeader http.Header, + cfg MultiReplicaSubscribeConfig, + clk quartz.Clock, +) ( + snapshot []codersdk.ChatStreamEvent, + parts <-chan codersdk.ChatStreamEvent, + cancel func(), + err error, +) { + address, ok := cfg.ResolveReplicaAddress(ctx, workerID) + if !ok { + return nil, nil, nil, xerrors.New("worker replica not found") + } + + baseURL, err := url.Parse(address) + if err != nil { + return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err) + } + replicaID := cfg.ReplicaIDFn() + relayCtx, relayCancel := context.WithCancel(ctx) + sdkClient := codersdk.New(baseURL) + sdkClient.HTTPClient = cfg.ReplicaHTTPClient + sdkClient.SessionTokenProvider = relayHeaderTokenProvider{ + header: relayHeaders(requestHeader, replicaID), + } + sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID, &codersdk.StreamChatOptions{ + AfterID: ptr.Ref(int64(math.MaxInt64)), + }) + if err != nil { + relayCancel() + return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err) + } + + snapshot = make([]codersdk.ChatStreamEvent, 0, 100) + + // Wait briefly for the first event to handle the common + // case where the remote side has buffered parts but hasn't + // flushed them to the WebSocket yet. + const drainTimeout = time.Second + drainTimer := clk.NewTimer(drainTimeout, "drain") + defer drainTimer.Stop() + +drainInitial: + for len(snapshot) < cap(snapshot) { + select { + case <-relayCtx.Done(): + _ = sourceStream.Close() + relayCancel() + return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err()) + case event, ok := <-sourceEvents: + if !ok { + break drainInitial + } + if event.Type != codersdk.ChatStreamEventTypeMessagePart { + continue + } + snapshot = append(snapshot, event) + // After getting the first event, switch to + // non-blocking drain for remaining buffered events. + drainTimer.Stop() + drainTimer.Reset(0) + case <-drainTimer.C: + break drainInitial + } + } + + events := make(chan codersdk.ChatStreamEvent, 128) + + go func() { + defer close(events) + defer relayCancel() + defer func() { + _ = sourceStream.Close() + }() + + // No need to re-send snapshot events — they're + // returned to the caller directly. + for { + select { + case <-relayCtx.Done(): + return + case event, ok := <-sourceEvents: + if !ok { + return + } + if event.Type != codersdk.ChatStreamEventTypeMessagePart { + continue + } + select { + case events <- event: + case <-relayCtx.Done(): + return + } + } + } + }() + + cancelFn := func() { + relayCancel() + _ = sourceStream.Close() + } + return snapshot, events, cancelFn, nil +} + +type relayHeaderTokenProvider struct { + header http.Header +} + +func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption { + return func(req *http.Request) { + for key, values := range p.header { + for _, value := range values { + req.Header.Add(key, value) + } + } + } +} + +func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) { + if opts.HTTPHeader == nil { + opts.HTTPHeader = make(http.Header) + } + for key, values := range p.header { + for _, value := range values { + opts.HTTPHeader.Add(key, value) + } + } +} + +func (p relayHeaderTokenProvider) GetSessionToken() string { + return p.header.Get(codersdk.SessionTokenHeader) +} + +func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header { + header := make(http.Header) + if source != nil { + for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} { + for _, value := range source.Values(key) { + header.Add(key, value) + } + } + } + header.Set(RelaySourceHeader, replicaID.String()) + return header +} diff --git a/enterprise/coderd/chatd/chatd_test.go b/enterprise/coderd/chatd/chatd_test.go new file mode 100644 index 0000000000..65cd410349 --- /dev/null +++ b/enterprise/coderd/chatd/chatd_test.go @@ -0,0 +1,1123 @@ +package chatd_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/http" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3/sloggers/slogtest" + osschatd "github.com/coder/coder/v2/coderd/chatd" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/codersdk" + entchatd "github.com/coder/coder/v2/enterprise/coderd/chatd" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func newTestServer( + t *testing.T, + db database.Store, + ps dbpubsub.Pubsub, + replicaID uuid.UUID, + dialer func( + ctx context.Context, + chatID uuid.UUID, + workerID uuid.UUID, + requestHeader http.Header, + ) ( + []codersdk.ChatStreamEvent, + <-chan codersdk.ChatStreamEvent, + func(), + error, + ), + clock quartz.Clock, +) *osschatd.Server { + t.Helper() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := osschatd.New(osschatd.Config{ + Logger: logger, + Database: db, + ReplicaID: replicaID, + Pubsub: ps, + SubscribeFn: entchatd.NewMultiReplicaSubscribeFn(entchatd.MultiReplicaSubscribeConfig{DialerFn: dialer, Clock: clock}), + PendingChatAcquireInterval: testutil.WaitSuperLong, + }) + t.Cleanup(func() { + require.NoError(t, server.Close()) + }) + return server +} + +// seedChatDependencies creates a user and chat model config in the +// database for use in relay tests. +func seedChatDependencies( + ctx context.Context, + t *testing.T, + db database.Store, +) (database.User, database.ChatModelConfig) { + t.Helper() + + user := dbgen.User(t, db, database.User{}) + _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + BaseUrl: "", + ApiKeyKeyID: sql.NullString{}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + }) + require.NoError(t, err) + model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + Provider: "openai", + Model: "gpt-4o-mini", + DisplayName: "Test Model", + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 70, + Options: json.RawMessage(`{}`), + }) + require.NoError(t, err) + return user, model +} + +func TestSubscribeRelayReconnectsOnDrop(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + workerID := uuid.New() + subscriberID := uuid.New() + + var callCount atomic.Int32 + + provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + call := callCount.Add(1) + ch := make(chan codersdk.ChatStreamEvent, 10) + if call == 1 { + // First relay: send a part then close to simulate a drop. + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first-relay"}, + }, + } + close(ch) + } else { + // Second relay: send a different part, keep open. + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "second-relay"}, + }, + } + // Don't close — keep alive so the subscriber stays connected. + } + return nil, ch, func() {}, nil + } + + mclk := quartz.NewMock(t) + // Trap the reconnect timer so we can fire it deterministically + // instead of waiting real time. + trapReconnect := mclk.Trap().NewTimer("reconnect") + defer trapReconnect.Close() + + subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create a chat and mark it as running on a remote worker. + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "relay-reconnect", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Should get the first relay part. + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "first-relay" { + return true + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + // Wait for the reconnect timer to be created after the relay + // drop, then advance the mock clock to fire it immediately. + trapReconnect.MustWait(ctx).MustRelease(ctx) + mclk.Advance(500 * time.Millisecond).MustWait(ctx) + + // After the first relay closes, the reconnection should deliver + // the second relay part. + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "second-relay" { + return true + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + require.GreaterOrEqual(t, int(callCount.Load()), 2) +} + +func TestSubscribeRelayAsyncDoesNotBlock(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + workerID := uuid.New() + subscriberID := uuid.New() + + dialStarted := make(chan struct{}) + dialContinue := make(chan struct{}) + + provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + // Signal that the dial has started, then block until released. + select { + case <-dialStarted: + default: + close(dialStarted) + } + select { + case <-dialContinue: + case <-ctx.Done(): + return nil, nil, nil, ctx.Err() + } + ch := make(chan codersdk.ChatStreamEvent, 10) + return nil, ch, func() {}, nil + } + + subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create a chat in pending status. + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "relay-async-nonblock", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + // Subscribe before the chat is marked running so the relay opens + // via pubsub notification (openRelayAsync path). + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Now mark the chat as running on a remote worker. This publishes + // a status notification which triggers openRelayAsync on the + // subscriber. + notify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: workerID.String(), + } + payload, err := json.Marshal(notify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) + require.NoError(t, err) + + // Wait for the relay dial to actually start (blocking in the + // provider). + select { + case <-dialStarted: + case <-ctx.Done(): + t.Fatal("timed out waiting for relay dial to start") + } + + // While the relay is still dialing (provider is blocked), publish + // another status change. If openRelayAsync blocked the select loop + // this event would never arrive. + statusNotify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusWaiting), + } + statusPayload, err := json.Marshal(statusNotify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), statusPayload) + require.NoError(t, err) + + // The waiting status event should arrive promptly despite the + // relay still dialing. + require.Eventually(t, func() bool { + select { + case event := <-events: + return event.Type == codersdk.ChatStreamEventTypeStatus && + event.Status != nil && + event.Status.Status == codersdk.ChatStatusWaiting + default: + return false + } + }, testutil.WaitShort, testutil.IntervalFast) + + // Unblock the relay dial so the test can clean up. + close(dialContinue) +} + +func TestSubscribeRelaySnapshotDelivered(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + workerID := uuid.New() + subscriberID := uuid.New() + + provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + // Return a non-empty snapshot with two parts. + snapshot := []codersdk.ChatStreamEvent{ + { + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-one"}, + }, + }, + { + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-two"}, + }, + }, + } + ch := make(chan codersdk.ChatStreamEvent, 10) + // Also send a live part after the snapshot. + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "live-part"}, + }, + } + return snapshot, ch, func() {}, nil + } + + subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create a chat already running on a remote worker. + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "relay-snapshot", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + + initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // The relay snapshot parts are forwarded through the events + // channel by the enterprise SubscribeFn. Collect them along + // with the live part. + var receivedTexts []string + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil { + receivedTexts = append(receivedTexts, event.MessagePart.Part.Text) + } + // We expect snap-one, snap-two, and live-part. + return len(receivedTexts) >= 3 + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts) + + // The initial snapshot should still contain the status event + // from the OSS preamble. + var hasStatus bool + for _, event := range initialSnapshot { + if event.Type == codersdk.ChatStreamEventTypeStatus { + hasStatus = true + } + } + require.True(t, hasStatus, "initial snapshot should contain status event") +} + +// TestSubscribeRelayStaleDialDiscardedAfterInterrupt verifies that when a +// user interrupts a streaming chat and sends a new message (which gets +// picked up by a different replica), an in-flight relay dial to the +// OLD replica is canceled/discarded and the relay connects to the +// NEW replica correctly. +func TestSubscribeRelayStaleDialDiscardedAfterInterrupt(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + oldWorkerID := uuid.New() + newWorkerID := uuid.New() + subscriberID := uuid.New() + + // Gate to hold the first dial until we're ready. + firstDialStarted := make(chan struct{}) + releaseFirstDial := make(chan struct{}) + + var callCount atomic.Int32 + + provider := func(ctx context.Context, _ uuid.UUID, workerID uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + call := callCount.Add(1) + ch := make(chan codersdk.ChatStreamEvent, 10) + if call == 1 { + // First dial (to old worker): signal that we started, + // then block until released or context canceled. + close(firstDialStarted) + select { + case <-releaseFirstDial: + case <-ctx.Done(): + return nil, nil, nil, ctx.Err() + } + // If we get here after being released (not canceled), + // return a stale part — this should be discarded. + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "stale-part"}, + }, + } + close(ch) + return nil, ch, func() {}, nil + } + // Second dial (to new worker): return a valid part. + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "new-worker-part"}, + }, + } + return nil, ch, func() {}, nil + } + + subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "stale-dial-test", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + // Start chat in waiting state so Subscribe does NOT try an initial relay. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + }) + require.NoError(t, err) + + // Subscribe while chat is in "waiting" state — no relay opened. + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Now simulate the chat being picked up by the OLD worker via pubsub. + // This triggers openRelayAsync in the merge loop. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: oldWorkerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + oldRunningNotify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: oldWorkerID.String(), + } + oldRunningPayload, err := json.Marshal(oldRunningNotify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), oldRunningPayload) + require.NoError(t, err) + + // Wait for the first dial goroutine to start (it's blocked in the provider). + select { + case <-firstDialStarted: + case <-ctx.Done(): + t.Fatal("timed out waiting for first dial to start") + } + + // Simulate interrupt: chat goes to "waiting". + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + }) + require.NoError(t, err) + waitingNotify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusWaiting), + } + waitingPayload, err := json.Marshal(waitingNotify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), waitingPayload) + require.NoError(t, err) + + // Wait for the merge loop to process the waiting notification + // and emit the status event before publishing the new running + // notification. This avoids time.Sleep (banned by project + // policy) and provides a deterministic sync point. + require.Eventually(t, func() bool { + select { + case event := <-events: + return event.Type == codersdk.ChatStreamEventTypeStatus && + event.Status != nil && + event.Status.Status == codersdk.ChatStatusWaiting + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + // Now the chat transitions to running on the NEW worker. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: newWorkerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + runningNotify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: newWorkerID.String(), + } + runningPayload, err := json.Marshal(runningNotify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), runningPayload) + require.NoError(t, err) + + // Now release the first dial (if it wasn't already canceled). + close(releaseFirstDial) + + // The subscriber should receive parts from the NEW worker, not the stale one. + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "new-worker-part" { + return true + } + // If we get the stale part, the bug is present. + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "stale-part" { + t.Fatal("received stale part from old worker — relay did not cancel in-flight dial") + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + // Drain the events channel for a while to ensure no late-arriving + // stale part sneaks in after the require.Eventually above returned. + // This closes the timing gap where "stale-part" could arrive after + // "new-worker-part" was already consumed. + require.Never(t, func() bool { + select { + case event := <-events: + return event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "stale-part" + default: + return false + } + }, 2*time.Second, testutil.IntervalFast) +} + +// TestSubscribeCancelDuringInFlightDial verifies that calling the +// subscription's cancel function while a relay dial goroutine is +// still blocking in the provider causes the provider's context to +// be canceled and the goroutine to return cleanly. +func TestSubscribeCancelDuringInFlightDial(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + workerID := uuid.New() + subscriberID := uuid.New() + + dialStarted := make(chan struct{}) + dialExited := make(chan struct{}) + + provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + // Signal the dial has started, then block until the context + // is canceled. + close(dialStarted) + <-ctx.Done() + close(dialExited) + return nil, nil, nil, ctx.Err() + } + + subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "cancel-inflight-dial", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + // Put the chat in waiting state so Subscribe does not open a + // synchronous relay. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + }) + require.NoError(t, err) + + _, _, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + + // Publish a running notification to trigger openRelayAsync. + notify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: workerID.String(), + } + payload, err := json.Marshal(notify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) + require.NoError(t, err) + + // Wait for the dial goroutine to block inside the provider. + select { + case <-dialStarted: + case <-ctx.Done(): + t.Fatal("timed out waiting for dial to start") + } + + // Cancel the subscription while the dial is still in-flight. + cancel() + + // The provider context must be canceled, causing the goroutine + // to return cleanly. + require.Eventually(t, func() bool { + select { + case <-dialExited: + return true + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) +} + +// TestSubscribeRelayRunningToRunningSwitch verifies that when a chat +// transitions directly from running(workerA) to running(workerB) +// without an intermediate waiting state, the relay switches to the +// new worker and discards parts from the old one. +func TestSubscribeRelayRunningToRunningSwitch(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + workerA := uuid.New() + workerB := uuid.New() + subscriberID := uuid.New() + + // Gate to hold workerA's dial until we verify cancellation. + dialAStarted := make(chan struct{}) + dialAExited := make(chan struct{}) + + var callCount atomic.Int32 + + provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + call := callCount.Add(1) + if call == 1 { + // First dial (to workerA): signal that we started, + // then block until the context is canceled. + close(dialAStarted) + <-ctx.Done() + close(dialAExited) + return nil, nil, nil, ctx.Err() + } + // Second dial (to workerB): return a valid part. + ch := make(chan codersdk.ChatStreamEvent, 10) + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "worker-b-part"}, + }, + } + return nil, ch, func() {}, nil + } + + subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "running-to-running", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + // Start in waiting state so Subscribe does not open a relay. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + }) + require.NoError(t, err) + + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Transition to running on workerA. + notifyA := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: workerA.String(), + } + payloadA, err := json.Marshal(notifyA) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payloadA) + require.NoError(t, err) + + // Wait for the workerA dial goroutine to block inside the + // provider before publishing the workerB notification. + select { + case <-dialAStarted: + case <-ctx.Done(): + t.Fatal("timed out waiting for workerA dial to start") + } + + // Immediately transition to running on workerB (no waiting in + // between). This should cancel workerA's in-flight dial. + notifyB := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: workerB.String(), + } + payloadB, err := json.Marshal(notifyB) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payloadB) + require.NoError(t, err) + + // Verify that the relay canceled workerA's stale dial. + require.Eventually(t, func() bool { + select { + case <-dialAExited: + return true + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + // We should receive the part from workerB. + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "worker-b-part" { + return true + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + require.Equal(t, 2, int(callCount.Load())) +} + +// TestSubscribeRelayFailedDialRetries verifies that when an async relay +// dial fails (returns an error), the merge loop schedules a reconnect +// timer and eventually re-dials successfully. This exercises the +// result.parts == nil path and the scheduleRelayReconnect() logic. +func TestSubscribeRelayFailedDialRetries(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + remoteWorkerID := uuid.New() + subscriberID := uuid.New() + + var callCount atomic.Int32 + + provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + call := callCount.Add(1) + if call == 1 { + // First dial: fail with an error to trigger + // scheduleRelayReconnect via the result.parts == nil path. + return nil, nil, nil, xerrors.New("transient dial failure") + } + // Second dial: succeed and return a part. + ch := make(chan codersdk.ChatStreamEvent, 10) + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "retry-success"}, + }, + } + return nil, ch, func() {}, nil + } + + mclk := quartz.NewMock(t) + // Trap the reconnect timer so we can fire it deterministically. + trapReconnect := mclk.Trap().NewTimer("reconnect") + defer trapReconnect.Close() + + subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create a chat in waiting state so Subscribe does not open a + // synchronous relay. + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "failed-dial-retry", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + // Keep the chat in waiting state so Subscribe does not attempt + // a synchronous relay dial. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + }) + require.NoError(t, err) + + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Now mark the chat as running on the remote worker in the DB. + // The reconnect timer calls params.DB.GetChatByID to check if + // the chat is still running on a remote worker, so this must be + // set before we advance the clock. + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: remoteWorkerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + + // Publish a running notification with a remote workerID to + // trigger openRelayAsync. The first dial will fail, causing + // scheduleRelayReconnect to be called. + notify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: remoteWorkerID.String(), + } + payload, err := json.Marshal(notify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) + require.NoError(t, err) + + // Wait for the reconnect timer to be created (after the failed + // dial), then advance the mock clock to fire it. + trapReconnect.MustWait(ctx).MustRelease(ctx) + mclk.Advance(500 * time.Millisecond).MustWait(ctx) + + // The merge loop re-checks the DB, sees the chat is still + // running on the remote worker, and dials again. The second + // dial succeeds. + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "retry-success" { + return true + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + require.GreaterOrEqual(t, int(callCount.Load()), 2) +} + +// TestSubscribeRunningLocalWorkerClosesRelay verifies that when a chat +// is running on a remote worker and a pubsub notification arrives +// saying the local worker (subscriberID) now owns the chat, the +// existing relay is closed and no new dial is started (the local +// worker serves directly without relaying). +func TestSubscribeRunningLocalWorkerClosesRelay(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + remoteWorkerID := uuid.New() + subscriberID := uuid.New() + + var callCount atomic.Int32 + + provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + call := callCount.Add(1) + ch := make(chan codersdk.ChatStreamEvent, 10) + if call == 1 { + // Initial synchronous dial to the remote worker. + ch <- codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "remote-part"}, + }, + } + // Keep channel open so the relay stays active. + } + return nil, ch, func() {}, nil + } + + subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create the chat already running on a remote worker so Subscribe + // opens a synchronous relay. + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "local-worker-closes-relay", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: remoteWorkerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Consume the remote-part from the initial relay. + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == "remote-part" { + return true + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + + // Notify that the LOCAL worker now owns the chat. This should + // close the relay without opening a new one. + notify := coderdpubsub.ChatStreamNotifyMessage{ + Status: string(database.ChatStatusRunning), + WorkerID: subscriberID.String(), + } + payload, err := json.Marshal(notify) + require.NoError(t, err) + err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) + require.NoError(t, err) + + // Give the system time to process the notification. No additional + // dial should happen — only the initial synchronous one. + require.Never(t, func() bool { + return int(callCount.Load()) > 1 + }, 2*time.Second, testutil.IntervalFast) + + require.Equal(t, 1, int(callCount.Load()), + "only the initial synchronous dial should have happened") +} + +// TestSubscribeRelayMultipleReconnects verifies that the reconnect +// loop handles multiple consecutive relay drops, proving it is +// robust across repeated iterations — not just the single reconnect +// already covered by TestSubscribeRelayReconnectsOnDrop. +func TestSubscribeRelayMultipleReconnects(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + workerID := uuid.New() + subscriberID := uuid.New() + + var callCount atomic.Int32 + + provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( + []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, + ) { + call := callCount.Add(1) + ch := make(chan codersdk.ChatStreamEvent, 10) + part := codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: "assistant", + Part: codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeText, + Text: fmt.Sprintf("relay-%d", call), + }, + }, + } + ch <- part + if call <= 2 { + // First two dials: close channel to simulate relay + // drop. This triggers scheduleRelayReconnect. + close(ch) + } + // Third dial: keep channel open. + return nil, ch, func() {}, nil + } + + mclk := quartz.NewMock(t) + // Trap the reconnect timer so we can fire both reconnects + // deterministically. + trapReconnect := mclk.Trap().NewTimer("reconnect") + defer trapReconnect.Close() + + subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Create a chat already running on a remote worker so + // Subscribe opens a synchronous relay immediately. + chat, err := subscriber.CreateChat(ctx, osschatd.CreateOptions{ + OwnerID: user.ID, + Title: "multiple-reconnects", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}}, + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + require.NoError(t, err) + + _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + // Helper to consume a specific relay part. + consumePart := func(text string) { + t.Helper() + require.Eventually(t, func() bool { + select { + case event := <-events: + if event.Type == codersdk.ChatStreamEventTypeMessagePart && + event.MessagePart != nil && + event.MessagePart.Part.Text == text { + return true + } + return false + default: + return false + } + }, testutil.WaitMedium, testutil.IntervalFast) + } + + // First relay: consumed immediately (synchronous dial). + consumePart("relay-1") + + // First relay drops → reconnect timer created. Advance clock + // to fire it. + trapReconnect.MustWait(ctx).MustRelease(ctx) + mclk.Advance(500 * time.Millisecond).MustWait(ctx) + + // Second relay part. + consumePart("relay-2") + + // Second relay drops → another reconnect timer. Advance again. + trapReconnect.MustWait(ctx).MustRelease(ctx) + mclk.Advance(500 * time.Millisecond).MustWait(ctx) + + // Third relay part (channel stays open). + consumePart("relay-3") + require.GreaterOrEqual(t, int(callCount.Load()), 3) +} diff --git a/enterprise/coderd/chats.go b/enterprise/coderd/chats.go deleted file mode 100644 index 204eac588d..0000000000 --- a/enterprise/coderd/chats.go +++ /dev/null @@ -1,177 +0,0 @@ -package coderd - -import ( - "context" - "net/http" - "net/url" - "time" - - "github.com/google/uuid" - "golang.org/x/xerrors" - - "github.com/coder/coder/v2/coderd/chatd" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/websocket" -) - -// RelaySourceHeader marks replica-relayed stream requests. -const RelaySourceHeader = "X-Coder-Relay-Source-Replica" - -const ( - authorizationHeader = "Authorization" - cookieHeader = "Cookie" -) - -// newRemotePartsProvider creates a RemotePartsProvider that dials a remote -// replica's stream endpoint to fetch message_part events. It filters to only -// forward message_part events since durable events come via pubsub. -func newRemotePartsProvider( - resolveReplicaAddress func(context.Context, uuid.UUID) (string, bool), - replicaHTTPClient *http.Client, - replicaID uuid.UUID, -) chatd.RemotePartsProvider { - return func( - ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ) { - address, ok := resolveReplicaAddress(ctx, workerID) - if !ok { - return nil, nil, nil, xerrors.New("worker replica not found") - } - - baseURL, err := url.Parse(address) - if err != nil { - return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err) - } - relayCtx, relayCancel := context.WithCancel(ctx) - sdkClient := codersdk.New(baseURL) - sdkClient.HTTPClient = replicaHTTPClient - sdkClient.SessionTokenProvider = relayHeaderTokenProvider{ - header: relayHeaders(requestHeader, replicaID), - } - sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID) - if err != nil { - relayCancel() - return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err) - } - - snapshot := make([]codersdk.ChatStreamEvent, 0, 100) - - // Wait briefly for the first event to handle the common - // case where the remote side has buffered parts but hasn't - // flushed them to the WebSocket yet. - const drainTimeout = time.Second - drainTimer := time.NewTimer(drainTimeout) - defer drainTimer.Stop() - - drainInitial: - for len(snapshot) < cap(snapshot) { - select { - case <-relayCtx.Done(): - _ = sourceStream.Close() - relayCancel() - return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err()) - case event, ok := <-sourceEvents: - if !ok { - break drainInitial - } - if event.Type != codersdk.ChatStreamEventTypeMessagePart { - continue - } - snapshot = append(snapshot, event) - // After getting the first event, switch to - // non-blocking drain for remaining buffered events. - drainTimer.Stop() - drainTimer.Reset(0) - case <-drainTimer.C: - break drainInitial - } - } - - events := make(chan codersdk.ChatStreamEvent, 128) - - go func() { - defer close(events) - defer relayCancel() - defer func() { - _ = sourceStream.Close() - }() - - // No need to re-send snapshot events — they're - // returned to the caller directly. - for { - select { - case <-relayCtx.Done(): - return - case event, ok := <-sourceEvents: - if !ok { - return - } - if event.Type != codersdk.ChatStreamEventTypeMessagePart { - continue - } - select { - case events <- event: - case <-relayCtx.Done(): - return - } - } - } - }() - - cancel := func() { - relayCancel() - _ = sourceStream.Close() - } - return snapshot, events, cancel, nil - } -} - -type relayHeaderTokenProvider struct { - header http.Header -} - -func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption { - return func(req *http.Request) { - for key, values := range p.header { - for _, value := range values { - req.Header.Add(key, value) - } - } - } -} - -func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) { - if opts.HTTPHeader == nil { - opts.HTTPHeader = make(http.Header) - } - for key, values := range p.header { - for _, value := range values { - opts.HTTPHeader.Add(key, value) - } - } -} - -func (p relayHeaderTokenProvider) GetSessionToken() string { - return p.header.Get(codersdk.SessionTokenHeader) -} - -func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header { - header := make(http.Header) - if source != nil { - for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} { - for _, value := range source.Values(key) { - header.Add(key, value) - } - } - } - header.Set(RelaySourceHeader, replicaID.String()) - return header -} diff --git a/enterprise/coderd/chats_test.go b/enterprise/coderd/chats_test.go index d920e6d4c6..9a45544aaf 100644 --- a/enterprise/coderd/chats_test.go +++ b/enterprise/coderd/chats_test.go @@ -131,7 +131,7 @@ func TestChatStreamRelay(t *testing.T) { ) } - firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID) + firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID, nil) require.NoError(t, err) defer firstStream.Close() @@ -151,7 +151,7 @@ func TestChatStreamRelay(t *testing.T) { firstEvent := waitForStreamTextPart(ctx, t, firstEvents, firstChunkText) require.Equal(t, "assistant", firstEvent.MessagePart.Role) - secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID) + secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID, nil) require.NoError(t, err) defer secondStream.Close() @@ -277,7 +277,7 @@ func TestChatStreamRelay(t *testing.T) { // Subscribe on the local (worker) replica so the stream is // consumed and chunks flow through the pipeline. - localEvents, localStream, err := localClient.StreamChat(ctx, chat.ID) + localEvents, localStream, err := localClient.StreamChat(ctx, chat.ID, nil) require.NoError(t, err) defer localStream.Close() @@ -308,7 +308,7 @@ func TestChatStreamRelay(t *testing.T) { // NOW connect the relay subscriber on the non-worker replica. // The relay must pick up all three buffered parts in its // initial snapshot via the drainInitial loop. - relayEvents, relayStream, err := relayClient.StreamChat(ctx, chat.ID) + relayEvents, relayStream, err := relayClient.StreamChat(ctx, chat.ID, nil) require.NoError(t, err) defer relayStream.Close() diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index e0ee96b3b1..caf8baae63 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -45,6 +45,7 @@ import ( agplusage "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/codersdk" + entchatd "github.com/coder/coder/v2/enterprise/coderd/chatd" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/dbauthz" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" @@ -191,8 +192,9 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { // This must happen before coderd initialization! options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader - // Wire up enterprise chat relay for cross-replica message_part streaming. - // Must be set before coderd.New so the chat processor gets it. + // Wire up enterprise chat subscription with cross-replica relay + // and pubsub coordination. Must be set before coderd.New so the + // chat processor receives it. replicaHTTPClient := replicaRelayHTTPClient(options.HTTPClient, meshTLSConfig) if replicaHTTPClient == nil { replicaHTTPClient = options.Options.HTTPClient @@ -200,33 +202,20 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { if replicaHTTPClient == nil { replicaHTTPClient = http.DefaultClient } - // Use a closure that captures api by reference so it can access api.AGPL.ID - // after coderd.New is called. The provider is only invoked when Subscribe - // is called, which happens after initialization, so api.AGPL will be set. - options.Options.ChatRemotePartsProvider = func( - ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ) { - // Get the replica ID from the API (will be set after coderd.New) - replicaID := api.AGPL.ID - if replicaID == uuid.Nil { - // Fallback if somehow called before initialization - replicaID = uuid.New() - } - provider := newRemotePartsProvider( - resolveReplicaAddress, - replicaHTTPClient, - replicaID, - ) - return provider(ctx, chatID, workerID, requestHeader) - } + // Use a closure that captures api by reference so it can access + // api.AGPL.ID after coderd.New is called. The SubscribeFn is + // only invoked from Subscribe, which happens after init. + options.Options.ChatSubscribeFn = entchatd.NewMultiReplicaSubscribeFn(entchatd.MultiReplicaSubscribeConfig{ + ResolveReplicaAddress: resolveReplicaAddress, + ReplicaHTTPClient: replicaHTTPClient, + ReplicaIDFn: func() uuid.UUID { + id := api.AGPL.ID + if id == uuid.Nil { + return uuid.New() + } + return id + }, + }) api.AGPL = coderd.New(options.Options) defer func() { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 1198f206de..698092b966 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -5432,6 +5432,20 @@ export interface StatsCollectionConfig { readonly usage_stats: UsageStatsConfig; } +// From codersdk/chats.go +/** + * StreamChatOptions are optional parameters for StreamChat. + */ +export interface StreamChatOptions { + /** + * AfterID limits the initial snapshot to messages created + * after the given ID. This is useful for relay connections + * that only need live message_part events and can skip the + * full message history. + */ + readonly AfterID: number | null; +} + // From codersdk/client.go /** * SubdomainAppSessionTokenCookie is the name of the cookie that stores an