mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(chatd): fix relay race conditions, extract enterprise relay logic, move pubsub to OSS (#22589)
## Summary Fixes a bug where interrupting a streaming chat and sending a new message left the relay connected to the wrong replica. Expanded into a broader refactor that cleanly separates concerns: - **OSS** owns pubsub subscription, message catch-up, queue updates, status forwarding, and local parts merging. - **Enterprise** (`enterprise/coderd/chatd`) only manages relay dialing, reconnection, and stale-dial discarding for cross-replica streaming. ## Architecture ### OSS `coderd/chatd/chatd.go` `Subscribe()` builds the initial snapshot then runs a single merge goroutine that handles: - Pubsub subscription for durable events (status, messages, queue, errors) - Message catch-up via `AfterMessageID` - Local `message_part` forwarding - Relay events from enterprise (when `SubscribeFn` is set) - Sends `StatusNotification` to enterprise so it can manage relay lifecycle Key types: - `SubscribeFn` — enterprise hook, returns relay-only events channel - `SubscribeFnParams` — `ChatID`, `Chat`, `WorkerID`, `StatusNotifications`, `RequestHeader`, `DB`, `Logger` - `StatusNotification` — `Status` + `WorkerID`, sent to enterprise on pubsub status changes ### Enterprise `enterprise/coderd/chatd/chatd.go` `NewMultiReplicaSubscribeFn(cfg MultiReplicaSubscribeConfig)` returns a `SubscribeFn` that: - Opens an initial synchronous relay if the chat is running on a remote worker - Reads `StatusNotifications` from OSS to open/close relay connections - Handles async dial, reconnect timers, stale-dial discarding - Returns only relay `message_part` events ## Bug fixes ### Original bug: stale relay dial after interrupt `openRelayAsync` goroutines used `mergedCtx` (subscription-level), not a per-dial context. `closeRelay()` could not cancel in-flight dials. When the user interrupts and a new replica picks up the chat, the old dial goroutine could complete after the new one and deliver a stale `relayResult`. **Fix**: per-dial `dialCtx`/`dialCancel`, `expectedWorkerID` tracking, `workerID` on `relayResult`. `closeRelay()` cancels the dial context and drains `relayReadyCh`. Merge loop rejects mismatched worker IDs. ### Additional fixes - `statusNotifications` send-on-closed-channel race — goroutine now owns `close()` via defer - Enterprise spin-loop on `StatusNotifications` close — two-value receive with nil-out - `hasPubsub` set from `p.pubsub != nil` instead of subscription success — now tracks actual subscription result - `lastMessageID` not initialized from `afterMessageID` — caused duplicate messages on catch-up - `wrappedParts` goroutine leaked remote connection on `dialCtx` cancel - `closeRelay()` did not drain `relayReadyCh` - `setChatWaiting` race with `SendMessage(Interrupt)` — wrapped in `InTx` - `processChat` post-TX side effects fired when chat was taken by another worker — added `errChatTakenByOtherWorker` sentinel - Cancel closure data race on `reconnectTimer` - Bare blocking send on pubsub error path - `localParts` hot-spin after channel close - No-pubsub branch dropped relay events and initial snapshot - Failed relay dial caused permanent stall (no reconnect retry) - DB error during reconnect timer caused permanent stall - `time.NewTimer` replaced with `quartz.Clock` for testable timing ## Tests 9 enterprise tests covering: - Relay reconnect on drop (mock clock) - Async dial does not block merge loop - Relay snapshot delivery - Stale dial discarded after interrupt - Cancel during in-flight dial - Running-to-running worker switch - Failed dial retries (mock clock) - Local worker closes relay - Multiple consecutive reconnects (mock clock) All pass with `-race`.
This commit is contained in:
+287
-301
@@ -62,7 +62,7 @@ type Server struct {
|
||||
workerID uuid.UUID
|
||||
logger slog.Logger
|
||||
|
||||
remotePartsProvider RemotePartsProvider
|
||||
subscribeFn SubscribeFn
|
||||
|
||||
agentConnFn AgentConnFunc
|
||||
createWorkspaceFn chattool.CreateWorkspaceFn
|
||||
@@ -93,24 +93,41 @@ type cachedInstruction struct {
|
||||
// AgentConnFunc provides access to workspace agent connections.
|
||||
type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error)
|
||||
|
||||
// ReplicaAddressResolver maps a replica ID to its relay address.
|
||||
type ReplicaAddressResolver func(context.Context, uuid.UUID) (string, bool)
|
||||
|
||||
// RemotePartsProvider returns a snapshot and live stream of message_part
|
||||
// events from the replica that is running the chat. Called when the chat
|
||||
// is actively running on a different replica. Nil in AGPL single-replica
|
||||
// deployments.
|
||||
type RemotePartsProvider func(
|
||||
// SubscribeFn replaces the default local-only subscription with a
|
||||
// multi-replica-aware implementation that merges pubsub notifications,
|
||||
// remote relay streams, and local parts into a single event channel.
|
||||
// When set, Subscribe delegates the event-merge goroutine to this
|
||||
// function instead of using simple local forwarding.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: subscription lifetime context (canceled on unsubscribe).
|
||||
// - params: all state needed to build the merged stream.
|
||||
//
|
||||
// Returns the merged event channel and a cleanup function.
|
||||
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
|
||||
type SubscribeFn func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
snapshot []codersdk.ChatStreamEvent,
|
||||
parts <-chan codersdk.ChatStreamEvent,
|
||||
cancel func(),
|
||||
err error,
|
||||
)
|
||||
params SubscribeFnParams,
|
||||
) (<-chan codersdk.ChatStreamEvent, func())
|
||||
|
||||
// StatusNotification informs the enterprise relay manager of chat
|
||||
// status changes so it can open or close relay connections.
|
||||
type StatusNotification struct {
|
||||
Status database.ChatStatus
|
||||
WorkerID uuid.UUID
|
||||
}
|
||||
|
||||
// SubscribeFnParams carries the state that the enterprise
|
||||
// SubscribeFn implementation needs from the OSS Subscribe preamble.
|
||||
type SubscribeFnParams struct {
|
||||
ChatID uuid.UUID
|
||||
Chat database.Chat
|
||||
WorkerID uuid.UUID
|
||||
StatusNotifications <-chan StatusNotification
|
||||
RequestHeader http.Header
|
||||
DB database.Store
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
type chatStreamState struct {
|
||||
buffer []codersdk.ChatStreamEvent
|
||||
@@ -129,6 +146,12 @@ var (
|
||||
ErrEditedMessageNotFound = xerrors.New("edited message not found")
|
||||
// ErrEditedMessageNotUser indicates a non-user message edit attempt.
|
||||
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
|
||||
|
||||
// errChatTakenByOtherWorker is a sentinel used inside the
|
||||
// processChat cleanup transaction to signal that another
|
||||
// worker acquired the chat, so all post-TX side effects
|
||||
// (status publish, pubsub, web push) must be skipped.
|
||||
errChatTakenByOtherWorker = xerrors.New("chat acquired by another worker")
|
||||
)
|
||||
|
||||
// CreateOptions controls chat creation in the shared chat mutation path.
|
||||
@@ -719,14 +742,31 @@ func setChatPendingWithStore(
|
||||
}
|
||||
|
||||
func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) {
|
||||
updatedChat, err := p.db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chatID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
var updatedChat database.Chat
|
||||
err := p.db.InTx(func(tx database.Store) error {
|
||||
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if lockErr != nil {
|
||||
return xerrors.Errorf("lock chat for waiting: %w", lockErr)
|
||||
}
|
||||
// If the chat has already transitioned to pending (e.g.
|
||||
// SendMessage with interrupt behavior), don't overwrite
|
||||
// it — the pending status takes priority so the new
|
||||
// message gets processed.
|
||||
if locked.Status == database.ChatStatusPending {
|
||||
updatedChat = locked
|
||||
return nil
|
||||
}
|
||||
var updateErr error
|
||||
updatedChat, updateErr = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chatID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
return updateErr
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
@@ -807,7 +847,7 @@ type Config struct {
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
ReplicaID uuid.UUID
|
||||
RemotePartsProvider RemotePartsProvider
|
||||
SubscribeFn SubscribeFn
|
||||
PendingChatAcquireInterval time.Duration
|
||||
InFlightChatStaleAfter time.Duration
|
||||
AgentConn AgentConnFunc
|
||||
@@ -844,7 +884,7 @@ func New(cfg Config) *Server {
|
||||
db: cfg.Database,
|
||||
workerID: workerID,
|
||||
logger: cfg.Logger.Named("chat-processor"),
|
||||
remotePartsProvider: cfg.RemotePartsProvider,
|
||||
subscribeFn: cfg.SubscribeFn,
|
||||
agentConnFn: cfg.AgentConn,
|
||||
createWorkspaceFn: cfg.CreateWorkspace,
|
||||
pubsub: cfg.Pubsub,
|
||||
@@ -954,10 +994,12 @@ func (p *Server) subscribeToStream(chatID uuid.UUID) (
|
||||
p.streamMu.Lock()
|
||||
state, ok := p.chatStreams[chatID]
|
||||
if ok {
|
||||
if subscriber, exists := state.subscribers[id]; exists {
|
||||
delete(state.subscribers, id)
|
||||
close(subscriber)
|
||||
}
|
||||
// Remove the subscriber but do not close the channel.
|
||||
// publishToStream copies subscriber references under
|
||||
// streamMu then sends outside the lock; closing here
|
||||
// races with that send and can panic. The channel
|
||||
// becomes unreachable once removed and will be GC'd.
|
||||
delete(state.subscribers, id)
|
||||
p.cleanupStreamIfIdleLocked(chatID, state)
|
||||
}
|
||||
p.streamMu.Unlock()
|
||||
@@ -1005,7 +1047,7 @@ func (p *Server) Subscribe(
|
||||
// Subscribe to local stream for message_parts (ephemeral).
|
||||
localSnapshot, localParts, localCancel := p.subscribeToStream(chatID)
|
||||
|
||||
// Build initial snapshot synchronously
|
||||
// Build initial snapshot synchronously.
|
||||
initialSnapshot := make([]codersdk.ChatStreamEvent, 0)
|
||||
// Add local message_parts to snapshot
|
||||
for _, event := range localSnapshot {
|
||||
@@ -1033,7 +1075,7 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
// Load initial queue
|
||||
// Load initial queue.
|
||||
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
|
||||
if err == nil && len(queued) > 0 {
|
||||
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
||||
@@ -1043,24 +1085,8 @@ func (p *Server) Subscribe(
|
||||
})
|
||||
}
|
||||
|
||||
// Get initial chat state to determine if we need a relay
|
||||
// Get initial chat state to determine if we need a relay.
|
||||
chat, err := p.db.GetChatByID(ctx, chatID)
|
||||
var relayCancel func()
|
||||
var relayParts <-chan codersdk.ChatStreamEvent
|
||||
if err == nil && chat.Status == database.ChatStatusRunning && chat.WorkerID.Valid && chat.WorkerID.UUID != p.workerID && p.remotePartsProvider != nil {
|
||||
// Open relay for initial snapshot
|
||||
snapshot, parts, cancel, err := p.remotePartsProvider(ctx, chatID, chat.WorkerID.UUID, requestHeader)
|
||||
if err == nil {
|
||||
relayCancel = cancel
|
||||
relayParts = parts
|
||||
// Add relay message_parts to snapshot
|
||||
for _, event := range snapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
initialSnapshot = append(initialSnapshot, event)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Include the current chat status in the snapshot so the
|
||||
// frontend can gate message_part processing correctly from
|
||||
@@ -1079,119 +1105,38 @@ func (p *Server) Subscribe(
|
||||
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
|
||||
}
|
||||
|
||||
// Track the last message ID we've seen for DB queries
|
||||
var lastMessageID int64
|
||||
// Track the last message ID we've seen for DB queries.
|
||||
// Initialize from afterMessageID so that when the caller passes
|
||||
// afterMessageID > 0 but no new messages exist yet, the first
|
||||
// pubsub catch-up doesn't re-fetch already-seen messages.
|
||||
lastMessageID := afterMessageID
|
||||
if len(messages) > 0 {
|
||||
lastMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
// Merge all event sources
|
||||
// Merge all event sources.
|
||||
mergedCtx, mergedCancel := context.WithCancel(ctx)
|
||||
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
|
||||
|
||||
var allCancels []func()
|
||||
allCancels = append(allCancels, localCancel)
|
||||
if relayCancel != nil {
|
||||
allCancels = append(allCancels, relayCancel)
|
||||
}
|
||||
|
||||
// Channel for async relay establishment.
|
||||
type relayResult struct {
|
||||
parts <-chan codersdk.ChatStreamEvent
|
||||
cancel func()
|
||||
}
|
||||
relayReadyCh := make(chan relayResult, 1)
|
||||
|
||||
// Reconnect timer state.
|
||||
var reconnectTimer *time.Timer
|
||||
var reconnectCh <-chan time.Time
|
||||
|
||||
// Helper to close relay and stop any pending reconnect timer.
|
||||
closeRelay := func() {
|
||||
if relayCancel != nil {
|
||||
relayCancel()
|
||||
relayCancel = nil
|
||||
}
|
||||
relayParts = nil
|
||||
if reconnectTimer != nil {
|
||||
reconnectTimer.Stop()
|
||||
reconnectTimer = nil
|
||||
reconnectCh = nil
|
||||
}
|
||||
}
|
||||
|
||||
// openRelayAsync dials the remote replica in a background
|
||||
// goroutine and delivers the result on relayReadyCh so the
|
||||
// main select loop is never blocked by network I/O.
|
||||
openRelayAsync := func(workerID uuid.UUID) {
|
||||
if p.remotePartsProvider == nil {
|
||||
return
|
||||
}
|
||||
closeRelay()
|
||||
go func() {
|
||||
snapshot, parts, cancel, err := p.remotePartsProvider(mergedCtx, chatID, workerID, requestHeader)
|
||||
if err != nil {
|
||||
p.logger.Warn(mergedCtx, "failed to open relay for message parts",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("worker_id", workerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
// Wrap the relay channel so snapshot parts are
|
||||
// delivered through the same channel as live parts.
|
||||
wrappedParts := make(chan codersdk.ChatStreamEvent, 128)
|
||||
go func() {
|
||||
defer close(wrappedParts)
|
||||
for _, event := range snapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
select {
|
||||
case wrappedParts <- event:
|
||||
case <-mergedCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
for event := range parts {
|
||||
select {
|
||||
case wrappedParts <- event:
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel}:
|
||||
case <-mergedCtx.Done():
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// scheduleRelayReconnect arms a short timer so the select
|
||||
// loop can re-check chat status and reopen the relay without
|
||||
// spinning in a tight loop.
|
||||
scheduleRelayReconnect := func() {
|
||||
if p.remotePartsProvider == nil {
|
||||
return
|
||||
}
|
||||
if reconnectTimer != nil {
|
||||
reconnectTimer.Stop()
|
||||
}
|
||||
reconnectTimer = time.NewTimer(500 * time.Millisecond)
|
||||
reconnectCh = reconnectTimer.C
|
||||
}
|
||||
|
||||
//nolint:nestif
|
||||
// Subscribe to pubsub for durable events (status, messages,
|
||||
// queue updates, errors). When pubsub is nil (e.g. in-memory
|
||||
// single-instance) we skip this and deliver all local events.
|
||||
var notifications <-chan coderdpubsub.ChatStreamNotifyMessage
|
||||
var errCh <-chan error
|
||||
if p.pubsub != nil {
|
||||
notifications := make(chan coderdpubsub.ChatStreamNotifyMessage, 10)
|
||||
errCh := make(chan error, 1)
|
||||
notifyCh := make(chan coderdpubsub.ChatStreamNotifyMessage, 10)
|
||||
errNotifyCh := make(chan error, 1)
|
||||
notifications = notifyCh
|
||||
errCh = errNotifyCh
|
||||
|
||||
listener := func(_ context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
listener := func(_ context.Context, message []byte, listenErr error) {
|
||||
if listenErr != nil {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
case errCh <- err:
|
||||
case errNotifyCh <- listenErr:
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1199,187 +1144,214 @@ func (p *Server) Subscribe(
|
||||
if unmarshalErr := json.Unmarshal(message, ¬ify); unmarshalErr != nil {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
case errCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr):
|
||||
case errNotifyCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr):
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
case notifications <- notify:
|
||||
case notifyCh <- notify:
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to pubsub for durable events
|
||||
if pubsubCancel, err := p.pubsub.SubscribeWithErr(
|
||||
if pubsubCancel, pubsubErr := p.pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
||||
listener,
|
||||
); err == nil {
|
||||
); pubsubErr == nil {
|
||||
allCancels = append(allCancels, pubsubCancel)
|
||||
} else {
|
||||
p.logger.Warn(mergedCtx, "failed to subscribe to chat stream notifications",
|
||||
p.logger.Warn(ctx, "failed to subscribe to chat stream notifications",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
slog.Error(pubsubErr),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle pubsub notifications in a goroutine
|
||||
go func() {
|
||||
defer close(mergedEvents)
|
||||
defer closeRelay()
|
||||
// When an enterprise SubscribeFn is provided and the chat
|
||||
// lookup succeeded, call it to get relay events (message_parts
|
||||
// from remote replicas). OSS now owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding;
|
||||
// enterprise only manages relay dialing.
|
||||
var relayEvents <-chan codersdk.ChatStreamEvent
|
||||
var relayCleanup func()
|
||||
var statusNotifications chan StatusNotification
|
||||
if p.subscribeFn != nil && err == nil {
|
||||
statusNotifications = make(chan StatusNotification, 10)
|
||||
var relayEvCh <-chan codersdk.ChatStreamEvent
|
||||
relayEvCh, relayCleanup = p.subscribeFn(mergedCtx, SubscribeFnParams{
|
||||
ChatID: chatID,
|
||||
Chat: chat,
|
||||
WorkerID: p.workerID,
|
||||
StatusNotifications: statusNotifications,
|
||||
RequestHeader: requestHeader,
|
||||
DB: p.db,
|
||||
Logger: p.logger,
|
||||
})
|
||||
relayEvents = relayEvCh
|
||||
}
|
||||
|
||||
for {
|
||||
relayPartsCh := relayParts
|
||||
hasPubsub := false
|
||||
if p.pubsub != nil {
|
||||
// hasPubsub is only true when we actually subscribed
|
||||
// successfully above (allCancels will contain the pubsub
|
||||
// cancel func in that case).
|
||||
hasPubsub = len(allCancels) > 1
|
||||
}
|
||||
|
||||
//nolint:nestif
|
||||
go func() {
|
||||
defer close(mergedEvents)
|
||||
if statusNotifications != nil {
|
||||
defer close(statusNotifications)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case psErr := <-errCh:
|
||||
p.logger.Error(mergedCtx, "chat stream pubsub error",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(psErr),
|
||||
)
|
||||
select {
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{
|
||||
Message: psErr.Error(),
|
||||
},
|
||||
}:
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case err := <-errCh:
|
||||
p.logger.Error(mergedCtx, "chat stream pubsub error",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{
|
||||
Message: err.Error(),
|
||||
},
|
||||
}
|
||||
return
|
||||
case result := <-relayReadyCh:
|
||||
// An async relay dial completed; swap in the
|
||||
// new relay channel.
|
||||
closeRelay()
|
||||
relayParts = result.parts
|
||||
relayCancel = result.cancel
|
||||
case <-reconnectCh:
|
||||
reconnectCh = nil
|
||||
// Re-check whether the chat is still running
|
||||
// on a remote worker before reconnecting.
|
||||
currentChat, chatErr := p.db.GetChatByID(mergedCtx, chatID)
|
||||
if chatErr == nil && currentChat.Status == database.ChatStatusRunning &&
|
||||
currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != p.workerID {
|
||||
openRelayAsync(currentChat.WorkerID.UUID)
|
||||
}
|
||||
case notify := <-notifications:
|
||||
// Handle different notification types
|
||||
if notify.AfterMessageID > 0 {
|
||||
// Read only new messages from DB.
|
||||
messages, err := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: lastMessageID,
|
||||
})
|
||||
if err == nil {
|
||||
for _, msg := range messages {
|
||||
sdkMsg := db2sdk.ChatMessage(msg)
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessage,
|
||||
ChatID: chatID,
|
||||
Message: &sdkMsg,
|
||||
}:
|
||||
}
|
||||
lastMessageID = msg.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
if notify.Status != "" {
|
||||
status := database.ChatStatus(notify.Status)
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)},
|
||||
}:
|
||||
}
|
||||
// Manage relay lifecycle based on status.
|
||||
if status == database.ChatStatusRunning && notify.WorkerID != "" {
|
||||
workerID, err := uuid.Parse(notify.WorkerID)
|
||||
if err == nil && workerID != p.workerID {
|
||||
openRelayAsync(workerID)
|
||||
} else if workerID == p.workerID {
|
||||
closeRelay()
|
||||
}
|
||||
} else {
|
||||
closeRelay()
|
||||
}
|
||||
}
|
||||
if notify.Error != "" {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{
|
||||
Message: notify.Error,
|
||||
},
|
||||
}:
|
||||
}
|
||||
}
|
||||
if notify.QueueUpdate {
|
||||
queued, err := p.db.GetChatQueuedMessages(mergedCtx, chatID)
|
||||
if err == nil {
|
||||
}
|
||||
return
|
||||
case notify := <-notifications:
|
||||
if notify.AfterMessageID > 0 {
|
||||
newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: lastMessageID,
|
||||
})
|
||||
if msgErr != nil {
|
||||
p.logger.Warn(mergedCtx, "failed to get chat messages after pubsub notification",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(msgErr),
|
||||
)
|
||||
} else {
|
||||
for _, msg := range newMessages {
|
||||
sdkMsg := db2sdk.ChatMessage(msg)
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
||||
ChatID: chatID,
|
||||
QueuedMessages: db2sdk.ChatQueuedMessages(queued),
|
||||
Type: codersdk.ChatStreamEventTypeMessage,
|
||||
ChatID: chatID,
|
||||
Message: &sdkMsg,
|
||||
}:
|
||||
}
|
||||
}
|
||||
}
|
||||
case event, ok := <-localParts:
|
||||
if !ok {
|
||||
// Local parts channel closed, but continue with pubsub
|
||||
continue
|
||||
}
|
||||
// Only forward message_part events from local (durable events come via pubsub)
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
}
|
||||
}
|
||||
case event, ok := <-relayPartsCh:
|
||||
if !ok {
|
||||
relayParts = nil
|
||||
// Schedule reconnection instead of giving up.
|
||||
scheduleRelayReconnect()
|
||||
continue
|
||||
}
|
||||
// Only forward message_part events from relay (durable events come via pubsub)
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
lastMessageID = msg.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
// No pubsub, just merge local parts.
|
||||
// localSnapshot was already included in initialSnapshot,
|
||||
// so only forward new events here.
|
||||
go func() {
|
||||
defer close(mergedEvents)
|
||||
for event := range localParts {
|
||||
if notify.Status != "" {
|
||||
status := database.ChatStatus(notify.Status)
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)},
|
||||
}:
|
||||
}
|
||||
// Notify enterprise relay manager if present.
|
||||
if statusNotifications != nil {
|
||||
workerID := uuid.Nil
|
||||
if notify.WorkerID != "" {
|
||||
if parsed, parseErr := uuid.Parse(notify.WorkerID); parseErr == nil {
|
||||
workerID = parsed
|
||||
}
|
||||
}
|
||||
select {
|
||||
case statusNotifications <- StatusNotification{Status: status, WorkerID: workerID}:
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if notify.Error != "" {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{
|
||||
Message: notify.Error,
|
||||
},
|
||||
}:
|
||||
}
|
||||
}
|
||||
if notify.QueueUpdate {
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID)
|
||||
if queueErr != nil {
|
||||
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(queueErr),
|
||||
)
|
||||
} else {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
||||
ChatID: chatID,
|
||||
QueuedMessages: db2sdk.ChatQueuedMessages(queuedMsgs),
|
||||
}:
|
||||
}
|
||||
}
|
||||
}
|
||||
case event, ok := <-localParts:
|
||||
if !ok {
|
||||
localParts = nil
|
||||
// Local parts channel closed. If pubsub is
|
||||
// active we continue with pubsub-driven events.
|
||||
// Otherwise terminate.
|
||||
if !hasPubsub {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hasPubsub {
|
||||
// Only forward message_part events from local
|
||||
// (durable events come via pubsub).
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No pubsub: forward all event types.
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
}
|
||||
}
|
||||
case event, ok := <-relayEvents:
|
||||
if !ok {
|
||||
relayEvents = nil
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
@@ -1387,11 +1359,10 @@ func (p *Server) Subscribe(
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
if reconnectTimer != nil {
|
||||
reconnectTimer.Stop()
|
||||
if relayCleanup != nil {
|
||||
relayCleanup()
|
||||
}
|
||||
}
|
||||
|
||||
return initialSnapshot, mergedEvents, cancel, true
|
||||
}
|
||||
|
||||
@@ -1733,6 +1704,15 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
return xerrors.Errorf("lock chat for release: %w", lockErr)
|
||||
}
|
||||
|
||||
// If another worker has already acquired this chat,
|
||||
// bail out — we must not overwrite their running
|
||||
// status or publish spurious events.
|
||||
if latestChat.Status == database.ChatStatusRunning &&
|
||||
latestChat.WorkerID.Valid &&
|
||||
latestChat.WorkerID.UUID != p.workerID {
|
||||
return errChatTakenByOtherWorker
|
||||
}
|
||||
|
||||
// If someone else already set the chat to pending (e.g.
|
||||
// the promote endpoint), don't overwrite it — just clear
|
||||
// the worker and let the processor pick it back up.
|
||||
@@ -1787,6 +1767,12 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
})
|
||||
return updateErr
|
||||
}, nil)
|
||||
if errors.Is(err, errChatTakenByOtherWorker) {
|
||||
// Another worker owns this chat now — skip all
|
||||
// post-TX side effects (status publish, pubsub,
|
||||
// web push) to avoid overwriting their state.
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error(cleanupCtx, "failed to release chat", slog.Error(err))
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -28,7 +27,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
@@ -1133,30 +1131,6 @@ func newTestServer(
|
||||
return server
|
||||
}
|
||||
|
||||
func newTestServerWithRelay(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps dbpubsub.Pubsub,
|
||||
replicaID uuid.UUID,
|
||||
provider chatd.RemotePartsProvider,
|
||||
) *chatd.Server {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: replicaID,
|
||||
Pubsub: ps,
|
||||
RemotePartsProvider: provider,
|
||||
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
return server
|
||||
}
|
||||
|
||||
func seedChatDependencies(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
@@ -1213,293 +1187,6 @@ func setOpenAIProviderBaseURL(
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
workerID := uuid.New()
|
||||
subscriberID := uuid.New()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
|
||||
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
|
||||
) {
|
||||
call := callCount.Add(1)
|
||||
ch := make(chan codersdk.ChatStreamEvent, 10)
|
||||
if call == 1 {
|
||||
// First relay: send a part then close to simulate a drop.
|
||||
ch <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: "assistant",
|
||||
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first-relay"},
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
} else {
|
||||
// Second relay: send a different part, keep open.
|
||||
ch <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: "assistant",
|
||||
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "second-relay"},
|
||||
},
|
||||
}
|
||||
// Don't close — keep alive so the subscriber stays connected.
|
||||
}
|
||||
return nil, ch, func() {}, nil
|
||||
}
|
||||
|
||||
subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a chat and mark it as running on a remote worker.
|
||||
chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "relay-reconnect",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Should get the first relay part.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
|
||||
event.MessagePart != nil &&
|
||||
event.MessagePart.Part.Text == "first-relay" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// After the first relay closes, a reconnection should happen and
|
||||
// deliver the second relay part.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
|
||||
event.MessagePart != nil &&
|
||||
event.MessagePart.Part.Text == "second-relay" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
require.GreaterOrEqual(t, int(callCount.Load()), 2)
|
||||
}
|
||||
|
||||
func TestSubscribeRelayAsyncDoesNotBlock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
workerID := uuid.New()
|
||||
subscriberID := uuid.New()
|
||||
|
||||
dialStarted := make(chan struct{})
|
||||
dialContinue := make(chan struct{})
|
||||
|
||||
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
|
||||
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
|
||||
) {
|
||||
// Signal that the dial has started, then block until released.
|
||||
select {
|
||||
case <-dialStarted:
|
||||
default:
|
||||
close(dialStarted)
|
||||
}
|
||||
select {
|
||||
case <-dialContinue:
|
||||
case <-ctx.Done():
|
||||
return nil, nil, nil, ctx.Err()
|
||||
}
|
||||
ch := make(chan codersdk.ChatStreamEvent, 10)
|
||||
return nil, ch, func() {}, nil
|
||||
}
|
||||
|
||||
subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a chat in pending status.
|
||||
chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "relay-async-nonblock",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscribe before the chat is marked running so the relay opens
|
||||
// via pubsub notification (openRelayAsync path).
|
||||
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Now mark the chat as running on a remote worker. This publishes
|
||||
// a status notification which triggers openRelayAsync on the
|
||||
// subscriber.
|
||||
notify := coderdpubsub.ChatStreamNotifyMessage{
|
||||
Status: string(database.ChatStatusRunning),
|
||||
WorkerID: workerID.String(),
|
||||
}
|
||||
payload, err := json.Marshal(notify)
|
||||
require.NoError(t, err)
|
||||
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the relay dial to actually start (blocking in the
|
||||
// provider).
|
||||
select {
|
||||
case <-dialStarted:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for relay dial to start")
|
||||
}
|
||||
|
||||
// While the relay is still dialing (provider is blocked), publish
|
||||
// another status change. If openRelayAsync blocked the select loop
|
||||
// this event would never arrive.
|
||||
statusNotify := coderdpubsub.ChatStreamNotifyMessage{
|
||||
Status: string(database.ChatStatusWaiting),
|
||||
}
|
||||
statusPayload, err := json.Marshal(statusNotify)
|
||||
require.NoError(t, err)
|
||||
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), statusPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The waiting status event should arrive promptly despite the
|
||||
// relay still dialing.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
return event.Type == codersdk.ChatStreamEventTypeStatus &&
|
||||
event.Status != nil &&
|
||||
event.Status.Status == codersdk.ChatStatusWaiting
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
// Unblock the relay dial so the test can clean up.
|
||||
close(dialContinue)
|
||||
}
|
||||
|
||||
func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
workerID := uuid.New()
|
||||
subscriberID := uuid.New()
|
||||
|
||||
provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
|
||||
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
|
||||
) {
|
||||
// Return a non-empty snapshot with two parts.
|
||||
snapshot := []codersdk.ChatStreamEvent{
|
||||
{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: "assistant",
|
||||
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-one"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: "assistant",
|
||||
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-two"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ch := make(chan codersdk.ChatStreamEvent, 10)
|
||||
// Also send a live part after the snapshot.
|
||||
ch <- codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{
|
||||
Role: "assistant",
|
||||
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "live-part"},
|
||||
},
|
||||
}
|
||||
return snapshot, ch, func() {}, nil
|
||||
}
|
||||
|
||||
subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a chat already running on a remote worker.
|
||||
chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "relay-snapshot",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// The initial snapshot should contain the two relay snapshot parts.
|
||||
var snapshotTexts []string
|
||||
for _, event := range initialSnapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart && event.MessagePart != nil {
|
||||
snapshotTexts = append(snapshotTexts, event.MessagePart.Part.Text)
|
||||
}
|
||||
}
|
||||
require.Contains(t, snapshotTexts, "snap-one")
|
||||
require.Contains(t, snapshotTexts, "snap-two")
|
||||
|
||||
// The live part should arrive on the events channel.
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
|
||||
event.MessagePart != nil &&
|
||||
event.MessagePart.Part.Text == "live-part" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1694,7 +1694,7 @@ func TestStreamChat(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
events, closer, err := client.StreamChat(ctx, chat.ID)
|
||||
events, closer, err := client.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer closer.Close()
|
||||
|
||||
|
||||
+11
-11
@@ -239,9 +239,9 @@ type Options struct {
|
||||
SSHConfig codersdk.SSHConfigResponse
|
||||
|
||||
HTTPClient *http.Client
|
||||
// ChatRemotePartsProvider provides cross-replica message_part streaming.
|
||||
// ChatSubscribeFn provides cross-replica subscription merging.
|
||||
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
|
||||
ChatRemotePartsProvider chatd.RemotePartsProvider
|
||||
ChatSubscribeFn chatd.SubscribeFn
|
||||
|
||||
UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||
StatsBatcher workspacestats.Batcher
|
||||
@@ -760,15 +760,15 @@ func New(options *Options) *API {
|
||||
api.agentProvider = stn
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
RemotePartsProvider: options.ChatRemotePartsProvider,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
})
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(stn)
|
||||
|
||||
+16
-2
@@ -670,15 +670,29 @@ func (c *Client) CreateChat(ctx context.Context, req CreateChatRequest) (Chat, e
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// StreamChatOptions are optional parameters for StreamChat.
|
||||
type StreamChatOptions struct {
|
||||
// AfterID limits the initial snapshot to messages created
|
||||
// after the given ID. This is useful for relay connections
|
||||
// that only need live message_part events and can skip the
|
||||
// full message history.
|
||||
AfterID *int64
|
||||
}
|
||||
|
||||
// StreamChat streams chat updates in real time.
|
||||
//
|
||||
// The returned channel includes initial snapshot events first, followed by
|
||||
// live updates. Callers must close the returned io.Closer to release the
|
||||
// websocket connection when done.
|
||||
func (c *Client) StreamChat(ctx context.Context, chatID uuid.UUID) (<-chan ChatStreamEvent, io.Closer, error) {
|
||||
func (c *Client) StreamChat(ctx context.Context, chatID uuid.UUID, opts *StreamChatOptions) (<-chan ChatStreamEvent, io.Closer, error) {
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s/stream", chatID)
|
||||
if opts != nil && opts.AfterID != nil {
|
||||
path += fmt.Sprintf("?after_id=%d", *opts.AfterID)
|
||||
}
|
||||
|
||||
conn, err := c.Dial(
|
||||
ctx,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/stream", chatID),
|
||||
path,
|
||||
&websocket.DialOptions{CompressionMode: websocket.CompressionDisabled},
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,575 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
osschatd "github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// RelaySourceHeader marks replica-relayed stream requests.
|
||||
const RelaySourceHeader = "X-Coder-Relay-Source-Replica"
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
cookieHeader = "Cookie"
|
||||
)
|
||||
|
||||
// MultiReplicaSubscribeConfig holds the dependencies for multi-replica chat
|
||||
// subscription. ReplicaIDFn is called lazily because the
|
||||
// replica ID may not be known at construction time.
|
||||
//
|
||||
// DialerFn, when set, overrides the default WebSocket relay
|
||||
// dialer. This is used in tests to inject mock relay behavior
|
||||
// without requiring real HTTP servers.
|
||||
type MultiReplicaSubscribeConfig struct {
|
||||
ResolveReplicaAddress func(context.Context, uuid.UUID) (string, bool)
|
||||
ReplicaHTTPClient *http.Client
|
||||
ReplicaIDFn func() uuid.UUID
|
||||
DialerFn func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
snapshot []codersdk.ChatStreamEvent,
|
||||
parts <-chan codersdk.ChatStreamEvent,
|
||||
cancel func(),
|
||||
err error,
|
||||
)
|
||||
// Clock is used for creating timers. In production use
|
||||
// quartz.NewReal(); in tests use quartz.NewMock(t) to
|
||||
// control reconnect timing deterministically.
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
// dial returns the dialer function to use for relay connections.
|
||||
// If DialerFn is set (e.g. in tests), it takes precedence.
|
||||
// Otherwise, dialRelay is used with the real MultiReplicaSubscribeConfig dependencies.
|
||||
// Returns nil when no relay capability is configured.
|
||||
func (c MultiReplicaSubscribeConfig) dial() func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
if c.DialerFn != nil {
|
||||
return c.DialerFn
|
||||
}
|
||||
if c.ResolveReplicaAddress == nil {
|
||||
return nil
|
||||
}
|
||||
return func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
return dialRelay(ctx, chatID, workerID, requestHeader, c, c.clock())
|
||||
}
|
||||
}
|
||||
|
||||
// clock returns the quartz.Clock to use. Defaults to a real clock
|
||||
// when not set.
|
||||
func (c MultiReplicaSubscribeConfig) clock() quartz.Clock {
|
||||
if c.Clock != nil {
|
||||
return c.Clock
|
||||
}
|
||||
return quartz.NewReal()
|
||||
}
|
||||
|
||||
// NewMultiReplicaSubscribeFn returns a SubscribeFn that manages
|
||||
// relay connections to remote replicas and returns relay
|
||||
// message_part events only. OSS handles pubsub subscription,
|
||||
// message catch-up, queue updates, status forwarding, and local
|
||||
// parts merging.
|
||||
//
|
||||
//nolint:gocognit // Complexity is inherent to the multi-source merge loop.
|
||||
func NewMultiReplicaSubscribeFn(
|
||||
cfg MultiReplicaSubscribeConfig,
|
||||
) osschatd.SubscribeFn {
|
||||
return func(ctx context.Context, params osschatd.SubscribeFnParams) (<-chan codersdk.ChatStreamEvent, func()) {
|
||||
chatID := params.ChatID
|
||||
requestHeader := params.RequestHeader
|
||||
logger := params.Logger
|
||||
|
||||
var relayCancel func()
|
||||
var relayParts <-chan codersdk.ChatStreamEvent
|
||||
|
||||
// If the chat is currently running on a different worker
|
||||
// and we have a remote parts provider, open an initial
|
||||
// relay synchronously so the caller gets in-flight
|
||||
// message_part events right away.
|
||||
var initialRelaySnapshot []codersdk.ChatStreamEvent
|
||||
if params.Chat.Status == database.ChatStatusRunning &&
|
||||
params.Chat.WorkerID.Valid &&
|
||||
params.Chat.WorkerID.UUID != params.WorkerID &&
|
||||
cfg.dial() != nil {
|
||||
snapshot, parts, cancel, err := cfg.dial()(ctx, chatID, params.Chat.WorkerID.UUID, requestHeader)
|
||||
if err == nil {
|
||||
relayCancel = cancel
|
||||
relayParts = parts
|
||||
// Collect relay message_parts to forward at the
|
||||
// start of the merge goroutine.
|
||||
for _, event := range snapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
initialRelaySnapshot = append(initialRelaySnapshot, event)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Warn(ctx, "failed to open initial relay for chat stream",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge all event sources.
|
||||
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
|
||||
var allCancels []func()
|
||||
if relayCancel != nil {
|
||||
allCancels = append(allCancels, relayCancel)
|
||||
}
|
||||
|
||||
// Channel for async relay establishment.
|
||||
type relayResult struct {
|
||||
parts <-chan codersdk.ChatStreamEvent
|
||||
cancel func()
|
||||
workerID uuid.UUID // the worker this dial targeted
|
||||
}
|
||||
relayReadyCh := make(chan relayResult, 1)
|
||||
|
||||
// Per-dial context so in-flight dials can be canceled when
|
||||
// a new dial is initiated or the relay is closed.
|
||||
var dialCancel context.CancelFunc
|
||||
|
||||
// expectedWorkerID tracks which replica we expect the next
|
||||
// relay result to target. Stale results are discarded.
|
||||
var expectedWorkerID uuid.UUID
|
||||
|
||||
// Reconnect timer state.
|
||||
var reconnectTimer *quartz.Timer
|
||||
var reconnectCh <-chan time.Time
|
||||
|
||||
// Helper to close relay and stop any pending reconnect
|
||||
// timer.
|
||||
closeRelay := func() {
|
||||
// Cancel any in-flight dial goroutine first.
|
||||
if dialCancel != nil {
|
||||
dialCancel()
|
||||
dialCancel = nil
|
||||
}
|
||||
// Drain any buffered relay result from a canceled
|
||||
// dial.
|
||||
select {
|
||||
case result := <-relayReadyCh:
|
||||
if result.cancel != nil {
|
||||
result.cancel()
|
||||
}
|
||||
default:
|
||||
}
|
||||
expectedWorkerID = uuid.Nil
|
||||
if relayCancel != nil {
|
||||
relayCancel()
|
||||
relayCancel = nil
|
||||
}
|
||||
relayParts = nil
|
||||
if reconnectTimer != nil {
|
||||
reconnectTimer.Stop()
|
||||
reconnectTimer = nil
|
||||
reconnectCh = nil
|
||||
}
|
||||
}
|
||||
|
||||
// openRelayAsync dials the remote replica in a background
|
||||
// goroutine and delivers the result on relayReadyCh so the
|
||||
// main select loop is never blocked by network I/O.
|
||||
openRelayAsync := func(workerID uuid.UUID) {
|
||||
if cfg.dial() == nil {
|
||||
return
|
||||
}
|
||||
closeRelay()
|
||||
// Create a per-dial context so this goroutine is
|
||||
// canceled if closeRelay() or openRelayAsync() is
|
||||
// called again before the dial completes.
|
||||
var dialCtx context.Context
|
||||
dialCtx, dialCancel = context.WithCancel(ctx)
|
||||
expectedWorkerID = workerID
|
||||
go func() {
|
||||
snapshot, parts, cancel, err := cfg.dial()(dialCtx, chatID, workerID, requestHeader)
|
||||
if err != nil {
|
||||
// Don't log context-canceled errors
|
||||
// since they are expected when a dial is
|
||||
// superseded by a newer one.
|
||||
if dialCtx.Err() == nil {
|
||||
logger.Warn(ctx, "failed to open relay for message parts",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("worker_id", workerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
// Send an empty result so the merge loop
|
||||
// can schedule a reconnect attempt.
|
||||
select {
|
||||
case relayReadyCh <- relayResult{workerID: workerID}:
|
||||
case <-dialCtx.Done():
|
||||
}
|
||||
return
|
||||
} // If the dial context was canceled while the
|
||||
// dial was in progress, discard the result to
|
||||
// avoid starting a wrappedParts goroutine for
|
||||
// a stale connection.
|
||||
if dialCtx.Err() != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
// Wrap the relay channel so snapshot parts
|
||||
// are delivered through the same channel as
|
||||
// live parts. This goroutine only forwards
|
||||
// events — it does not own the relay
|
||||
// lifecycle. When dialCtx is canceled it
|
||||
// simply returns, closing wrappedParts via
|
||||
// its defer. The cancel() is called by
|
||||
// whoever canceled dialCtx (closeRelay or
|
||||
// the send-fallback select below).
|
||||
wrappedParts := make(chan codersdk.ChatStreamEvent, 128)
|
||||
go func() {
|
||||
defer close(wrappedParts)
|
||||
for _, event := range snapshot {
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
select {
|
||||
case wrappedParts <- event:
|
||||
case <-dialCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-parts:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case wrappedParts <- event:
|
||||
case <-dialCtx.Done():
|
||||
return
|
||||
}
|
||||
case <-dialCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel, workerID: workerID}:
|
||||
case <-dialCtx.Done():
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// scheduleRelayReconnect arms a short timer so the select
|
||||
// loop can re-check chat status and reopen the relay
|
||||
// without spinning in a tight loop.
|
||||
scheduleRelayReconnect := func() {
|
||||
if cfg.dial() == nil {
|
||||
return
|
||||
}
|
||||
if reconnectTimer != nil {
|
||||
reconnectTimer.Stop()
|
||||
}
|
||||
reconnectTimer = cfg.clock().NewTimer(500*time.Millisecond, "reconnect")
|
||||
reconnectCh = reconnectTimer.C
|
||||
}
|
||||
|
||||
statusNotifications := params.StatusNotifications
|
||||
go func() {
|
||||
defer close(mergedEvents)
|
||||
defer closeRelay()
|
||||
|
||||
// Forward any initial relay snapshot parts
|
||||
// collected synchronously above.
|
||||
for _, event := range initialRelaySnapshot {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
relayPartsCh := relayParts
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case result := <-relayReadyCh:
|
||||
// Discard stale relay results from a
|
||||
// previous dial that was superseded.
|
||||
if result.workerID != expectedWorkerID {
|
||||
if result.cancel != nil {
|
||||
result.cancel()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// A nil parts channel signals the dial
|
||||
// failed — schedule a retry.
|
||||
if result.parts == nil {
|
||||
scheduleRelayReconnect()
|
||||
continue
|
||||
}
|
||||
// An async relay dial completed; swap
|
||||
// in the new relay channel.
|
||||
if relayCancel != nil {
|
||||
relayCancel()
|
||||
}
|
||||
relayParts = result.parts
|
||||
relayCancel = result.cancel
|
||||
case <-reconnectCh:
|
||||
reconnectCh = nil
|
||||
// Re-check whether the chat is still
|
||||
// running on a remote worker before
|
||||
// reconnecting.
|
||||
currentChat, chatErr := params.DB.GetChatByID(ctx, chatID)
|
||||
if chatErr != nil {
|
||||
logger.Warn(ctx, "failed to get chat for relay reconnect",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(chatErr),
|
||||
)
|
||||
// Retry on transient DB errors to
|
||||
// avoid permanently stalling the
|
||||
// stream.
|
||||
scheduleRelayReconnect()
|
||||
continue
|
||||
}
|
||||
if currentChat.Status == database.ChatStatusRunning &&
|
||||
currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != params.WorkerID {
|
||||
openRelayAsync(currentChat.WorkerID.UUID)
|
||||
}
|
||||
case sn, ok := <-statusNotifications:
|
||||
if !ok {
|
||||
statusNotifications = nil
|
||||
continue
|
||||
}
|
||||
if sn.Status == database.ChatStatusRunning && sn.WorkerID != uuid.Nil && sn.WorkerID != params.WorkerID {
|
||||
openRelayAsync(sn.WorkerID)
|
||||
} else {
|
||||
closeRelay()
|
||||
}
|
||||
case event, ok := <-relayPartsCh:
|
||||
if !ok {
|
||||
if relayCancel != nil {
|
||||
relayCancel()
|
||||
relayCancel = nil
|
||||
}
|
||||
relayParts = nil
|
||||
// Schedule reconnection instead of
|
||||
// giving up.
|
||||
scheduleRelayReconnect()
|
||||
continue
|
||||
}
|
||||
// Only forward message_part events from
|
||||
// relay.
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case mergedEvents <- event:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// The cancel function tears down the relay state
|
||||
// indirectly: the merge goroutine owns all relay state
|
||||
// (reconnectTimer, relayCancel, dialCancel, etc.) and
|
||||
// cleans it up via its defer closeRelay() when ctx is
|
||||
// canceled.
|
||||
cancel := func() {
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
return mergedEvents, cancel
|
||||
}
|
||||
}
|
||||
|
||||
// dialRelay opens a WebSocket relay connection to the replica
|
||||
// identified by workerID and returns a snapshot of buffered
|
||||
// message_part events plus a live channel of subsequent events.
|
||||
// It passes afterID=MaxInt64 so the remote replica skips the
|
||||
// full message history snapshot, since the relay only needs
|
||||
// live message_part events.
|
||||
func dialRelay(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
cfg MultiReplicaSubscribeConfig,
|
||||
clk quartz.Clock,
|
||||
) (
|
||||
snapshot []codersdk.ChatStreamEvent,
|
||||
parts <-chan codersdk.ChatStreamEvent,
|
||||
cancel func(),
|
||||
err error,
|
||||
) {
|
||||
address, ok := cfg.ResolveReplicaAddress(ctx, workerID)
|
||||
if !ok {
|
||||
return nil, nil, nil, xerrors.New("worker replica not found")
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err)
|
||||
}
|
||||
replicaID := cfg.ReplicaIDFn()
|
||||
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||
sdkClient := codersdk.New(baseURL)
|
||||
sdkClient.HTTPClient = cfg.ReplicaHTTPClient
|
||||
sdkClient.SessionTokenProvider = relayHeaderTokenProvider{
|
||||
header: relayHeaders(requestHeader, replicaID),
|
||||
}
|
||||
sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID, &codersdk.StreamChatOptions{
|
||||
AfterID: ptr.Ref(int64(math.MaxInt64)),
|
||||
})
|
||||
if err != nil {
|
||||
relayCancel()
|
||||
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err)
|
||||
}
|
||||
|
||||
snapshot = make([]codersdk.ChatStreamEvent, 0, 100)
|
||||
|
||||
// Wait briefly for the first event to handle the common
|
||||
// case where the remote side has buffered parts but hasn't
|
||||
// flushed them to the WebSocket yet.
|
||||
const drainTimeout = time.Second
|
||||
drainTimer := clk.NewTimer(drainTimeout, "drain")
|
||||
defer drainTimer.Stop()
|
||||
|
||||
drainInitial:
|
||||
for len(snapshot) < cap(snapshot) {
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
_ = sourceStream.Close()
|
||||
relayCancel()
|
||||
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err())
|
||||
case event, ok := <-sourceEvents:
|
||||
if !ok {
|
||||
break drainInitial
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
||||
continue
|
||||
}
|
||||
snapshot = append(snapshot, event)
|
||||
// After getting the first event, switch to
|
||||
// non-blocking drain for remaining buffered events.
|
||||
drainTimer.Stop()
|
||||
drainTimer.Reset(0)
|
||||
case <-drainTimer.C:
|
||||
break drainInitial
|
||||
}
|
||||
}
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 128)
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
defer relayCancel()
|
||||
defer func() {
|
||||
_ = sourceStream.Close()
|
||||
}()
|
||||
|
||||
// No need to re-send snapshot events — they're
|
||||
// returned to the caller directly.
|
||||
for {
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
case event, ok := <-sourceEvents:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case events <- event:
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
cancelFn := func() {
|
||||
relayCancel()
|
||||
_ = sourceStream.Close()
|
||||
}
|
||||
return snapshot, events, cancelFn, nil
|
||||
}
|
||||
|
||||
type relayHeaderTokenProvider struct {
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption {
|
||||
return func(req *http.Request) {
|
||||
for key, values := range p.header {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
||||
if opts.HTTPHeader == nil {
|
||||
opts.HTTPHeader = make(http.Header)
|
||||
}
|
||||
for key, values := range p.header {
|
||||
for _, value := range values {
|
||||
opts.HTTPHeader.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) GetSessionToken() string {
|
||||
return p.header.Get(codersdk.SessionTokenHeader)
|
||||
}
|
||||
|
||||
func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header {
|
||||
header := make(http.Header)
|
||||
if source != nil {
|
||||
for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} {
|
||||
for _, value := range source.Values(key) {
|
||||
header.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
header.Set(RelaySourceHeader, replicaID.String())
|
||||
return header
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,177 +0,0 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// RelaySourceHeader marks replica-relayed stream requests.
|
||||
const RelaySourceHeader = "X-Coder-Relay-Source-Replica"
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
cookieHeader = "Cookie"
|
||||
)
|
||||
|
||||
// newRemotePartsProvider creates a RemotePartsProvider that dials a remote
|
||||
// replica's stream endpoint to fetch message_part events. It filters to only
|
||||
// forward message_part events since durable events come via pubsub.
|
||||
func newRemotePartsProvider(
|
||||
resolveReplicaAddress func(context.Context, uuid.UUID) (string, bool),
|
||||
replicaHTTPClient *http.Client,
|
||||
replicaID uuid.UUID,
|
||||
) chatd.RemotePartsProvider {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
address, ok := resolveReplicaAddress(ctx, workerID)
|
||||
if !ok {
|
||||
return nil, nil, nil, xerrors.New("worker replica not found")
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err)
|
||||
}
|
||||
relayCtx, relayCancel := context.WithCancel(ctx)
|
||||
sdkClient := codersdk.New(baseURL)
|
||||
sdkClient.HTTPClient = replicaHTTPClient
|
||||
sdkClient.SessionTokenProvider = relayHeaderTokenProvider{
|
||||
header: relayHeaders(requestHeader, replicaID),
|
||||
}
|
||||
sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID)
|
||||
if err != nil {
|
||||
relayCancel()
|
||||
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err)
|
||||
}
|
||||
|
||||
snapshot := make([]codersdk.ChatStreamEvent, 0, 100)
|
||||
|
||||
// Wait briefly for the first event to handle the common
|
||||
// case where the remote side has buffered parts but hasn't
|
||||
// flushed them to the WebSocket yet.
|
||||
const drainTimeout = time.Second
|
||||
drainTimer := time.NewTimer(drainTimeout)
|
||||
defer drainTimer.Stop()
|
||||
|
||||
drainInitial:
|
||||
for len(snapshot) < cap(snapshot) {
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
_ = sourceStream.Close()
|
||||
relayCancel()
|
||||
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err())
|
||||
case event, ok := <-sourceEvents:
|
||||
if !ok {
|
||||
break drainInitial
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
||||
continue
|
||||
}
|
||||
snapshot = append(snapshot, event)
|
||||
// After getting the first event, switch to
|
||||
// non-blocking drain for remaining buffered events.
|
||||
drainTimer.Stop()
|
||||
drainTimer.Reset(0)
|
||||
case <-drainTimer.C:
|
||||
break drainInitial
|
||||
}
|
||||
}
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 128)
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
defer relayCancel()
|
||||
defer func() {
|
||||
_ = sourceStream.Close()
|
||||
}()
|
||||
|
||||
// No need to re-send snapshot events — they're
|
||||
// returned to the caller directly.
|
||||
for {
|
||||
select {
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
case event, ok := <-sourceEvents:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case events <- event:
|
||||
case <-relayCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
cancel := func() {
|
||||
relayCancel()
|
||||
_ = sourceStream.Close()
|
||||
}
|
||||
return snapshot, events, cancel, nil
|
||||
}
|
||||
}
|
||||
|
||||
type relayHeaderTokenProvider struct {
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption {
|
||||
return func(req *http.Request) {
|
||||
for key, values := range p.header {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
||||
if opts.HTTPHeader == nil {
|
||||
opts.HTTPHeader = make(http.Header)
|
||||
}
|
||||
for key, values := range p.header {
|
||||
for _, value := range values {
|
||||
opts.HTTPHeader.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p relayHeaderTokenProvider) GetSessionToken() string {
|
||||
return p.header.Get(codersdk.SessionTokenHeader)
|
||||
}
|
||||
|
||||
func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header {
|
||||
header := make(http.Header)
|
||||
if source != nil {
|
||||
for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} {
|
||||
for _, value := range source.Values(key) {
|
||||
header.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
header.Set(RelaySourceHeader, replicaID.String())
|
||||
return header
|
||||
}
|
||||
@@ -131,7 +131,7 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID)
|
||||
firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer firstStream.Close()
|
||||
|
||||
@@ -151,7 +151,7 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
firstEvent := waitForStreamTextPart(ctx, t, firstEvents, firstChunkText)
|
||||
require.Equal(t, "assistant", firstEvent.MessagePart.Role)
|
||||
|
||||
secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID)
|
||||
secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer secondStream.Close()
|
||||
|
||||
@@ -277,7 +277,7 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
|
||||
// Subscribe on the local (worker) replica so the stream is
|
||||
// consumed and chunks flow through the pipeline.
|
||||
localEvents, localStream, err := localClient.StreamChat(ctx, chat.ID)
|
||||
localEvents, localStream, err := localClient.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer localStream.Close()
|
||||
|
||||
@@ -308,7 +308,7 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
// NOW connect the relay subscriber on the non-worker replica.
|
||||
// The relay must pick up all three buffered parts in its
|
||||
// initial snapshot via the drainInitial loop.
|
||||
relayEvents, relayStream, err := relayClient.StreamChat(ctx, chat.ID)
|
||||
relayEvents, relayStream, err := relayClient.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer relayStream.Close()
|
||||
|
||||
|
||||
+18
-29
@@ -45,6 +45,7 @@ import (
|
||||
agplusage "github.com/coder/coder/v2/coderd/usage"
|
||||
"github.com/coder/coder/v2/coderd/wsbuilder"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
entchatd "github.com/coder/coder/v2/enterprise/coderd/chatd"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/dbauthz"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/enidpsync"
|
||||
@@ -191,8 +192,9 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
// This must happen before coderd initialization!
|
||||
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
|
||||
|
||||
// Wire up enterprise chat relay for cross-replica message_part streaming.
|
||||
// Must be set before coderd.New so the chat processor gets it.
|
||||
// Wire up enterprise chat subscription with cross-replica relay
|
||||
// and pubsub coordination. Must be set before coderd.New so the
|
||||
// chat processor receives it.
|
||||
replicaHTTPClient := replicaRelayHTTPClient(options.HTTPClient, meshTLSConfig)
|
||||
if replicaHTTPClient == nil {
|
||||
replicaHTTPClient = options.Options.HTTPClient
|
||||
@@ -200,33 +202,20 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
if replicaHTTPClient == nil {
|
||||
replicaHTTPClient = http.DefaultClient
|
||||
}
|
||||
// Use a closure that captures api by reference so it can access api.AGPL.ID
|
||||
// after coderd.New is called. The provider is only invoked when Subscribe
|
||||
// is called, which happens after initialization, so api.AGPL will be set.
|
||||
options.Options.ChatRemotePartsProvider = func(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
workerID uuid.UUID,
|
||||
requestHeader http.Header,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
error,
|
||||
) {
|
||||
// Get the replica ID from the API (will be set after coderd.New)
|
||||
replicaID := api.AGPL.ID
|
||||
if replicaID == uuid.Nil {
|
||||
// Fallback if somehow called before initialization
|
||||
replicaID = uuid.New()
|
||||
}
|
||||
provider := newRemotePartsProvider(
|
||||
resolveReplicaAddress,
|
||||
replicaHTTPClient,
|
||||
replicaID,
|
||||
)
|
||||
return provider(ctx, chatID, workerID, requestHeader)
|
||||
}
|
||||
// Use a closure that captures api by reference so it can access
|
||||
// api.AGPL.ID after coderd.New is called. The SubscribeFn is
|
||||
// only invoked from Subscribe, which happens after init.
|
||||
options.Options.ChatSubscribeFn = entchatd.NewMultiReplicaSubscribeFn(entchatd.MultiReplicaSubscribeConfig{
|
||||
ResolveReplicaAddress: resolveReplicaAddress,
|
||||
ReplicaHTTPClient: replicaHTTPClient,
|
||||
ReplicaIDFn: func() uuid.UUID {
|
||||
id := api.AGPL.ID
|
||||
if id == uuid.Nil {
|
||||
return uuid.New()
|
||||
}
|
||||
return id
|
||||
},
|
||||
})
|
||||
|
||||
api.AGPL = coderd.New(options.Options)
|
||||
defer func() {
|
||||
|
||||
Generated
+14
@@ -5432,6 +5432,20 @@ export interface StatsCollectionConfig {
|
||||
readonly usage_stats: UsageStatsConfig;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* StreamChatOptions are optional parameters for StreamChat.
|
||||
*/
|
||||
export interface StreamChatOptions {
|
||||
/**
|
||||
* AfterID limits the initial snapshot to messages created
|
||||
* after the given ID. This is useful for relay connections
|
||||
* that only need live message_part events and can skip the
|
||||
* full message history.
|
||||
*/
|
||||
readonly AfterID: number | null;
|
||||
}
|
||||
|
||||
// From codersdk/client.go
|
||||
/**
|
||||
* SubdomainAppSessionTokenCookie is the name of the cookie that stores an
|
||||
|
||||
Reference in New Issue
Block a user