mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
30d534b36b
## Summary Fixes a bug where interrupting a streaming chat and sending a new message left the relay connected to the wrong replica. Expanded into a broader refactor that cleanly separates concerns: - **OSS** owns pubsub subscription, message catch-up, queue updates, status forwarding, and local parts merging. - **Enterprise** (`enterprise/coderd/chatd`) only manages relay dialing, reconnection, and stale-dial discarding for cross-replica streaming. ## Architecture ### OSS `coderd/chatd/chatd.go` `Subscribe()` builds the initial snapshot then runs a single merge goroutine that handles: - Pubsub subscription for durable events (status, messages, queue, errors) - Message catch-up via `AfterMessageID` - Local `message_part` forwarding - Relay events from enterprise (when `SubscribeFn` is set) - Sends `StatusNotification` to enterprise so it can manage relay lifecycle Key types: - `SubscribeFn` — enterprise hook, returns relay-only events channel - `SubscribeFnParams` — `ChatID`, `Chat`, `WorkerID`, `StatusNotifications`, `RequestHeader`, `DB`, `Logger` - `StatusNotification` — `Status` + `WorkerID`, sent to enterprise on pubsub status changes ### Enterprise `enterprise/coderd/chatd/chatd.go` `NewMultiReplicaSubscribeFn(cfg MultiReplicaSubscribeConfig)` returns a `SubscribeFn` that: - Opens an initial synchronous relay if the chat is running on a remote worker - Reads `StatusNotifications` from OSS to open/close relay connections - Handles async dial, reconnect timers, stale-dial discarding - Returns only relay `message_part` events ## Bug fixes ### Original bug: stale relay dial after interrupt `openRelayAsync` goroutines used `mergedCtx` (subscription-level), not a per-dial context. `closeRelay()` could not cancel in-flight dials. When the user interrupts and a new replica picks up the chat, the old dial goroutine could complete after the new one and deliver a stale `relayResult`. **Fix**: per-dial `dialCtx`/`dialCancel`, `expectedWorkerID` tracking, `workerID` on `relayResult`. `closeRelay()` cancels the dial context and drains `relayReadyCh`. Merge loop rejects mismatched worker IDs. ### Additional fixes - `statusNotifications` send-on-closed-channel race — goroutine now owns `close()` via defer - Enterprise spin-loop on `StatusNotifications` close — two-value receive with nil-out - `hasPubsub` set from `p.pubsub != nil` instead of subscription success — now tracks actual subscription result - `lastMessageID` not initialized from `afterMessageID` — caused duplicate messages on catch-up - `wrappedParts` goroutine leaked remote connection on `dialCtx` cancel - `closeRelay()` did not drain `relayReadyCh` - `setChatWaiting` race with `SendMessage(Interrupt)` — wrapped in `InTx` - `processChat` post-TX side effects fired when chat was taken by another worker — added `errChatTakenByOtherWorker` sentinel - Cancel closure data race on `reconnectTimer` - Bare blocking send on pubsub error path - `localParts` hot-spin after channel close - No-pubsub branch dropped relay events and initial snapshot - Failed relay dial caused permanent stall (no reconnect retry) - DB error during reconnect timer caused permanent stall - `time.NewTimer` replaced with `quartz.Clock` for testable timing ## Tests 9 enterprise tests covering: - Relay reconnect on drop (mock clock) - Async dial does not block merge loop - Relay snapshot delivery - Stale dial discarded after interrupt - Cancel during in-flight dial - Running-to-running worker switch - Failed dial retries (mock clock) - Local worker closes relay - Multiple consecutive reconnects (mock clock) All pass with `-race`.
576 lines
16 KiB
Go
576 lines
16 KiB
Go
package chatd
|
|
|
|
import (
|
|
"context"
|
|
"math"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
osschatd "github.com/coder/coder/v2/coderd/chatd"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/quartz"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
// RelaySourceHeader marks replica-relayed stream requests.
|
|
const RelaySourceHeader = "X-Coder-Relay-Source-Replica"
|
|
|
|
const (
|
|
authorizationHeader = "Authorization"
|
|
cookieHeader = "Cookie"
|
|
)
|
|
|
|
// MultiReplicaSubscribeConfig holds the dependencies for multi-replica chat
|
|
// subscription. ReplicaIDFn is called lazily because the
|
|
// replica ID may not be known at construction time.
|
|
//
|
|
// DialerFn, when set, overrides the default WebSocket relay
|
|
// dialer. This is used in tests to inject mock relay behavior
|
|
// without requiring real HTTP servers.
|
|
type MultiReplicaSubscribeConfig struct {
|
|
ResolveReplicaAddress func(context.Context, uuid.UUID) (string, bool)
|
|
ReplicaHTTPClient *http.Client
|
|
ReplicaIDFn func() uuid.UUID
|
|
DialerFn func(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
workerID uuid.UUID,
|
|
requestHeader http.Header,
|
|
) (
|
|
snapshot []codersdk.ChatStreamEvent,
|
|
parts <-chan codersdk.ChatStreamEvent,
|
|
cancel func(),
|
|
err error,
|
|
)
|
|
// Clock is used for creating timers. In production use
|
|
// quartz.NewReal(); in tests use quartz.NewMock(t) to
|
|
// control reconnect timing deterministically.
|
|
Clock quartz.Clock
|
|
}
|
|
|
|
// dial returns the dialer function to use for relay connections.
|
|
// If DialerFn is set (e.g. in tests), it takes precedence.
|
|
// Otherwise, dialRelay is used with the real MultiReplicaSubscribeConfig dependencies.
|
|
// Returns nil when no relay capability is configured.
|
|
func (c MultiReplicaSubscribeConfig) dial() func(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
workerID uuid.UUID,
|
|
requestHeader http.Header,
|
|
) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
error,
|
|
) {
|
|
if c.DialerFn != nil {
|
|
return c.DialerFn
|
|
}
|
|
if c.ResolveReplicaAddress == nil {
|
|
return nil
|
|
}
|
|
return func(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
workerID uuid.UUID,
|
|
requestHeader http.Header,
|
|
) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
error,
|
|
) {
|
|
return dialRelay(ctx, chatID, workerID, requestHeader, c, c.clock())
|
|
}
|
|
}
|
|
|
|
// clock returns the quartz.Clock to use. Defaults to a real clock
|
|
// when not set.
|
|
func (c MultiReplicaSubscribeConfig) clock() quartz.Clock {
|
|
if c.Clock != nil {
|
|
return c.Clock
|
|
}
|
|
return quartz.NewReal()
|
|
}
|
|
|
|
// NewMultiReplicaSubscribeFn returns a SubscribeFn that manages
|
|
// relay connections to remote replicas and returns relay
|
|
// message_part events only. OSS handles pubsub subscription,
|
|
// message catch-up, queue updates, status forwarding, and local
|
|
// parts merging.
|
|
//
|
|
//nolint:gocognit // Complexity is inherent to the multi-source merge loop.
|
|
func NewMultiReplicaSubscribeFn(
|
|
cfg MultiReplicaSubscribeConfig,
|
|
) osschatd.SubscribeFn {
|
|
return func(ctx context.Context, params osschatd.SubscribeFnParams) (<-chan codersdk.ChatStreamEvent, func()) {
|
|
chatID := params.ChatID
|
|
requestHeader := params.RequestHeader
|
|
logger := params.Logger
|
|
|
|
var relayCancel func()
|
|
var relayParts <-chan codersdk.ChatStreamEvent
|
|
|
|
// If the chat is currently running on a different worker
|
|
// and we have a remote parts provider, open an initial
|
|
// relay synchronously so the caller gets in-flight
|
|
// message_part events right away.
|
|
var initialRelaySnapshot []codersdk.ChatStreamEvent
|
|
if params.Chat.Status == database.ChatStatusRunning &&
|
|
params.Chat.WorkerID.Valid &&
|
|
params.Chat.WorkerID.UUID != params.WorkerID &&
|
|
cfg.dial() != nil {
|
|
snapshot, parts, cancel, err := cfg.dial()(ctx, chatID, params.Chat.WorkerID.UUID, requestHeader)
|
|
if err == nil {
|
|
relayCancel = cancel
|
|
relayParts = parts
|
|
// Collect relay message_parts to forward at the
|
|
// start of the merge goroutine.
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
initialRelaySnapshot = append(initialRelaySnapshot, event)
|
|
}
|
|
}
|
|
} else {
|
|
logger.Warn(ctx, "failed to open initial relay for chat stream",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Merge all event sources.
|
|
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
|
|
var allCancels []func()
|
|
if relayCancel != nil {
|
|
allCancels = append(allCancels, relayCancel)
|
|
}
|
|
|
|
// Channel for async relay establishment.
|
|
type relayResult struct {
|
|
parts <-chan codersdk.ChatStreamEvent
|
|
cancel func()
|
|
workerID uuid.UUID // the worker this dial targeted
|
|
}
|
|
relayReadyCh := make(chan relayResult, 1)
|
|
|
|
// Per-dial context so in-flight dials can be canceled when
|
|
// a new dial is initiated or the relay is closed.
|
|
var dialCancel context.CancelFunc
|
|
|
|
// expectedWorkerID tracks which replica we expect the next
|
|
// relay result to target. Stale results are discarded.
|
|
var expectedWorkerID uuid.UUID
|
|
|
|
// Reconnect timer state.
|
|
var reconnectTimer *quartz.Timer
|
|
var reconnectCh <-chan time.Time
|
|
|
|
// Helper to close relay and stop any pending reconnect
|
|
// timer.
|
|
closeRelay := func() {
|
|
// Cancel any in-flight dial goroutine first.
|
|
if dialCancel != nil {
|
|
dialCancel()
|
|
dialCancel = nil
|
|
}
|
|
// Drain any buffered relay result from a canceled
|
|
// dial.
|
|
select {
|
|
case result := <-relayReadyCh:
|
|
if result.cancel != nil {
|
|
result.cancel()
|
|
}
|
|
default:
|
|
}
|
|
expectedWorkerID = uuid.Nil
|
|
if relayCancel != nil {
|
|
relayCancel()
|
|
relayCancel = nil
|
|
}
|
|
relayParts = nil
|
|
if reconnectTimer != nil {
|
|
reconnectTimer.Stop()
|
|
reconnectTimer = nil
|
|
reconnectCh = nil
|
|
}
|
|
}
|
|
|
|
// openRelayAsync dials the remote replica in a background
|
|
// goroutine and delivers the result on relayReadyCh so the
|
|
// main select loop is never blocked by network I/O.
|
|
openRelayAsync := func(workerID uuid.UUID) {
|
|
if cfg.dial() == nil {
|
|
return
|
|
}
|
|
closeRelay()
|
|
// Create a per-dial context so this goroutine is
|
|
// canceled if closeRelay() or openRelayAsync() is
|
|
// called again before the dial completes.
|
|
var dialCtx context.Context
|
|
dialCtx, dialCancel = context.WithCancel(ctx)
|
|
expectedWorkerID = workerID
|
|
go func() {
|
|
snapshot, parts, cancel, err := cfg.dial()(dialCtx, chatID, workerID, requestHeader)
|
|
if err != nil {
|
|
// Don't log context-canceled errors
|
|
// since they are expected when a dial is
|
|
// superseded by a newer one.
|
|
if dialCtx.Err() == nil {
|
|
logger.Warn(ctx, "failed to open relay for message parts",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("worker_id", workerID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
// Send an empty result so the merge loop
|
|
// can schedule a reconnect attempt.
|
|
select {
|
|
case relayReadyCh <- relayResult{workerID: workerID}:
|
|
case <-dialCtx.Done():
|
|
}
|
|
return
|
|
} // If the dial context was canceled while the
|
|
// dial was in progress, discard the result to
|
|
// avoid starting a wrappedParts goroutine for
|
|
// a stale connection.
|
|
if dialCtx.Err() != nil {
|
|
cancel()
|
|
return
|
|
}
|
|
// Wrap the relay channel so snapshot parts
|
|
// are delivered through the same channel as
|
|
// live parts. This goroutine only forwards
|
|
// events — it does not own the relay
|
|
// lifecycle. When dialCtx is canceled it
|
|
// simply returns, closing wrappedParts via
|
|
// its defer. The cancel() is called by
|
|
// whoever canceled dialCtx (closeRelay or
|
|
// the send-fallback select below).
|
|
wrappedParts := make(chan codersdk.ChatStreamEvent, 128)
|
|
go func() {
|
|
defer close(wrappedParts)
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
select {
|
|
case wrappedParts <- event:
|
|
case <-dialCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
for {
|
|
select {
|
|
case event, ok := <-parts:
|
|
if !ok {
|
|
return
|
|
}
|
|
select {
|
|
case wrappedParts <- event:
|
|
case <-dialCtx.Done():
|
|
return
|
|
}
|
|
case <-dialCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
select {
|
|
case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel, workerID: workerID}:
|
|
case <-dialCtx.Done():
|
|
cancel()
|
|
}
|
|
}()
|
|
}
|
|
|
|
// scheduleRelayReconnect arms a short timer so the select
|
|
// loop can re-check chat status and reopen the relay
|
|
// without spinning in a tight loop.
|
|
scheduleRelayReconnect := func() {
|
|
if cfg.dial() == nil {
|
|
return
|
|
}
|
|
if reconnectTimer != nil {
|
|
reconnectTimer.Stop()
|
|
}
|
|
reconnectTimer = cfg.clock().NewTimer(500*time.Millisecond, "reconnect")
|
|
reconnectCh = reconnectTimer.C
|
|
}
|
|
|
|
statusNotifications := params.StatusNotifications
|
|
go func() {
|
|
defer close(mergedEvents)
|
|
defer closeRelay()
|
|
|
|
// Forward any initial relay snapshot parts
|
|
// collected synchronously above.
|
|
for _, event := range initialRelaySnapshot {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
|
|
for {
|
|
relayPartsCh := relayParts
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case result := <-relayReadyCh:
|
|
// Discard stale relay results from a
|
|
// previous dial that was superseded.
|
|
if result.workerID != expectedWorkerID {
|
|
if result.cancel != nil {
|
|
result.cancel()
|
|
}
|
|
continue
|
|
}
|
|
// A nil parts channel signals the dial
|
|
// failed — schedule a retry.
|
|
if result.parts == nil {
|
|
scheduleRelayReconnect()
|
|
continue
|
|
}
|
|
// An async relay dial completed; swap
|
|
// in the new relay channel.
|
|
if relayCancel != nil {
|
|
relayCancel()
|
|
}
|
|
relayParts = result.parts
|
|
relayCancel = result.cancel
|
|
case <-reconnectCh:
|
|
reconnectCh = nil
|
|
// Re-check whether the chat is still
|
|
// running on a remote worker before
|
|
// reconnecting.
|
|
currentChat, chatErr := params.DB.GetChatByID(ctx, chatID)
|
|
if chatErr != nil {
|
|
logger.Warn(ctx, "failed to get chat for relay reconnect",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(chatErr),
|
|
)
|
|
// Retry on transient DB errors to
|
|
// avoid permanently stalling the
|
|
// stream.
|
|
scheduleRelayReconnect()
|
|
continue
|
|
}
|
|
if currentChat.Status == database.ChatStatusRunning &&
|
|
currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != params.WorkerID {
|
|
openRelayAsync(currentChat.WorkerID.UUID)
|
|
}
|
|
case sn, ok := <-statusNotifications:
|
|
if !ok {
|
|
statusNotifications = nil
|
|
continue
|
|
}
|
|
if sn.Status == database.ChatStatusRunning && sn.WorkerID != uuid.Nil && sn.WorkerID != params.WorkerID {
|
|
openRelayAsync(sn.WorkerID)
|
|
} else {
|
|
closeRelay()
|
|
}
|
|
case event, ok := <-relayPartsCh:
|
|
if !ok {
|
|
if relayCancel != nil {
|
|
relayCancel()
|
|
relayCancel = nil
|
|
}
|
|
relayParts = nil
|
|
// Schedule reconnection instead of
|
|
// giving up.
|
|
scheduleRelayReconnect()
|
|
continue
|
|
}
|
|
// Only forward message_part events from
|
|
// relay.
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// The cancel function tears down the relay state
|
|
// indirectly: the merge goroutine owns all relay state
|
|
// (reconnectTimer, relayCancel, dialCancel, etc.) and
|
|
// cleans it up via its defer closeRelay() when ctx is
|
|
// canceled.
|
|
cancel := func() {
|
|
for _, cancelFn := range allCancels {
|
|
if cancelFn != nil {
|
|
cancelFn()
|
|
}
|
|
}
|
|
}
|
|
return mergedEvents, cancel
|
|
}
|
|
}
|
|
|
|
// dialRelay opens a WebSocket relay connection to the replica
|
|
// identified by workerID and returns a snapshot of buffered
|
|
// message_part events plus a live channel of subsequent events.
|
|
// It passes afterID=MaxInt64 so the remote replica skips the
|
|
// full message history snapshot, since the relay only needs
|
|
// live message_part events.
|
|
func dialRelay(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
workerID uuid.UUID,
|
|
requestHeader http.Header,
|
|
cfg MultiReplicaSubscribeConfig,
|
|
clk quartz.Clock,
|
|
) (
|
|
snapshot []codersdk.ChatStreamEvent,
|
|
parts <-chan codersdk.ChatStreamEvent,
|
|
cancel func(),
|
|
err error,
|
|
) {
|
|
address, ok := cfg.ResolveReplicaAddress(ctx, workerID)
|
|
if !ok {
|
|
return nil, nil, nil, xerrors.New("worker replica not found")
|
|
}
|
|
|
|
baseURL, err := url.Parse(address)
|
|
if err != nil {
|
|
return nil, nil, nil, xerrors.Errorf("parse relay address %q: %w", address, err)
|
|
}
|
|
replicaID := cfg.ReplicaIDFn()
|
|
relayCtx, relayCancel := context.WithCancel(ctx)
|
|
sdkClient := codersdk.New(baseURL)
|
|
sdkClient.HTTPClient = cfg.ReplicaHTTPClient
|
|
sdkClient.SessionTokenProvider = relayHeaderTokenProvider{
|
|
header: relayHeaders(requestHeader, replicaID),
|
|
}
|
|
sourceEvents, sourceStream, err := sdkClient.StreamChat(relayCtx, chatID, &codersdk.StreamChatOptions{
|
|
AfterID: ptr.Ref(int64(math.MaxInt64)),
|
|
})
|
|
if err != nil {
|
|
relayCancel()
|
|
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", err)
|
|
}
|
|
|
|
snapshot = make([]codersdk.ChatStreamEvent, 0, 100)
|
|
|
|
// Wait briefly for the first event to handle the common
|
|
// case where the remote side has buffered parts but hasn't
|
|
// flushed them to the WebSocket yet.
|
|
const drainTimeout = time.Second
|
|
drainTimer := clk.NewTimer(drainTimeout, "drain")
|
|
defer drainTimer.Stop()
|
|
|
|
drainInitial:
|
|
for len(snapshot) < cap(snapshot) {
|
|
select {
|
|
case <-relayCtx.Done():
|
|
_ = sourceStream.Close()
|
|
relayCancel()
|
|
return nil, nil, nil, xerrors.Errorf("dial relay stream: %w", relayCtx.Err())
|
|
case event, ok := <-sourceEvents:
|
|
if !ok {
|
|
break drainInitial
|
|
}
|
|
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
|
continue
|
|
}
|
|
snapshot = append(snapshot, event)
|
|
// After getting the first event, switch to
|
|
// non-blocking drain for remaining buffered events.
|
|
drainTimer.Stop()
|
|
drainTimer.Reset(0)
|
|
case <-drainTimer.C:
|
|
break drainInitial
|
|
}
|
|
}
|
|
|
|
events := make(chan codersdk.ChatStreamEvent, 128)
|
|
|
|
go func() {
|
|
defer close(events)
|
|
defer relayCancel()
|
|
defer func() {
|
|
_ = sourceStream.Close()
|
|
}()
|
|
|
|
// No need to re-send snapshot events — they're
|
|
// returned to the caller directly.
|
|
for {
|
|
select {
|
|
case <-relayCtx.Done():
|
|
return
|
|
case event, ok := <-sourceEvents:
|
|
if !ok {
|
|
return
|
|
}
|
|
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
|
continue
|
|
}
|
|
select {
|
|
case events <- event:
|
|
case <-relayCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
cancelFn := func() {
|
|
relayCancel()
|
|
_ = sourceStream.Close()
|
|
}
|
|
return snapshot, events, cancelFn, nil
|
|
}
|
|
|
|
type relayHeaderTokenProvider struct {
|
|
header http.Header
|
|
}
|
|
|
|
func (p relayHeaderTokenProvider) AsRequestOption() codersdk.RequestOption {
|
|
return func(req *http.Request) {
|
|
for key, values := range p.header {
|
|
for _, value := range values {
|
|
req.Header.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p relayHeaderTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
|
if opts.HTTPHeader == nil {
|
|
opts.HTTPHeader = make(http.Header)
|
|
}
|
|
for key, values := range p.header {
|
|
for _, value := range values {
|
|
opts.HTTPHeader.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p relayHeaderTokenProvider) GetSessionToken() string {
|
|
return p.header.Get(codersdk.SessionTokenHeader)
|
|
}
|
|
|
|
func relayHeaders(source http.Header, replicaID uuid.UUID) http.Header {
|
|
header := make(http.Header)
|
|
if source != nil {
|
|
for _, key := range []string{codersdk.SessionTokenHeader, authorizationHeader, cookieHeader} {
|
|
for _, value := range source.Values(key) {
|
|
header.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
header.Set(RelaySourceHeader, replicaID.String())
|
|
return header
|
|
}
|