Files
coder/enterprise/coderd/x/chatd/chatd_test.go
T

1248 lines
38 KiB
Go

package chatd_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
osschatd "github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
entchatd "github.com/coder/coder/v2/enterprise/coderd/x/chatd"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func newTestServer(
t *testing.T,
db database.Store,
ps dbpubsub.Pubsub,
replicaID uuid.UUID,
dialer func(
ctx context.Context,
chatID uuid.UUID,
workerID uuid.UUID,
requestHeader http.Header,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
error,
),
clock quartz.Clock,
) *osschatd.Server {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := osschatd.New(osschatd.Config{
Logger: logger,
Database: db,
ReplicaID: replicaID,
Pubsub: ps,
SubscribeFn: entchatd.NewMultiReplicaSubscribeFn(entchatd.MultiReplicaSubscribeConfig{DialerFn: dialer, Clock: clock}),
PendingChatAcquireInterval: testutil.WaitSuperLong,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
return server
}
func newActiveWorkerServer(
t *testing.T,
db database.Store,
ps dbpubsub.Pubsub,
replicaID uuid.UUID,
) *osschatd.Server {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := osschatd.New(osschatd.Config{
Logger: logger,
Database: db,
ReplicaID: replicaID,
Pubsub: ps,
PendingChatAcquireInterval: 10 * time.Millisecond,
InFlightChatStaleAfter: testutil.WaitSuperLong,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
return server
}
// seedChatDependencies creates a user and chat model config in the
// database for use in relay tests.
func seedChatDependencies(
ctx context.Context,
t *testing.T,
db database.Store,
) (database.User, database.ChatModelConfig) {
t.Helper()
safetyNet := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write([]byte(`{"error":{"message":"unexpected OpenAI request in chatd relay test safety net"}}`))
}))
t.Cleanup(safetyNet.Close)
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseUrl: safetyNet.URL,
CentralApiKeyEnabled: true,
ApiKeyKeyID: sql.NullString{},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
return user, model
}
func seedWaitingChat(
ctx context.Context,
t *testing.T,
db database.Store,
user database.User,
model database.ChatModelConfig,
title string,
) database.Chat {
t.Helper()
chat, err := db.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: title,
MCPServerIDs: []uuid.UUID{},
})
require.NoError(t, err)
return chat
}
func seedRemoteRunningChat(
ctx context.Context,
t *testing.T,
db database.Store,
user database.User,
model database.ChatModelConfig,
workerID uuid.UUID,
title string,
) database.Chat {
t.Helper()
chat := seedWaitingChat(ctx, t, db, user, model, title)
now := time.Now()
chat, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
StartedAt: sql.NullTime{Time: now, Valid: true},
HeartbeatAt: sql.NullTime{Time: now, Valid: true},
})
require.NoError(t, err)
return chat
}
func setOpenAIProviderBaseURL(
ctx context.Context,
t *testing.T,
db database.Store,
baseURL string,
) {
t.Helper()
provider, err := db.GetChatProviderByProvider(ctx, "openai")
require.NoError(t, err)
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
ID: provider.ID,
DisplayName: provider.DisplayName,
APIKey: provider.APIKey,
BaseUrl: baseURL,
CentralApiKeyEnabled: true,
AllowUserApiKey: false,
AllowCentralApiKeyFallback: false,
ApiKeyKeyID: provider.ApiKeyKeyID,
Enabled: provider.Enabled,
})
require.NoError(t, err)
}
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerID := uuid.New()
subscriberID := uuid.New()
var callCount atomic.Int32
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
call := callCount.Add(1)
ch := make(chan codersdk.ChatStreamEvent, 10)
if call == 1 {
// First relay: send a part then close to simulate a drop.
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("first-relay"),
},
}
close(ch)
} else {
// Second relay: send a different part, keep open.
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("second-relay"),
},
}
// Don't close — keep alive so the subscriber stays connected.
}
return nil, ch, func() {}, nil
}
mclk := quartz.NewMock(t)
// Trap the reconnect timer so we can fire it deterministically
// instead of waiting real time.
trapReconnect := mclk.Trap().NewTimer("reconnect")
defer trapReconnect.Close()
subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat := seedRemoteRunningChat(ctx, t, db, user, model, workerID, "relay-reconnect")
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Should get the first relay part.
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "first-relay" {
return true
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
// Wait for the reconnect timer to be created after the relay
// drop, then advance the mock clock to fire it immediately.
trapReconnect.MustWait(ctx).MustRelease(ctx)
mclk.Advance(500 * time.Millisecond).MustWait(ctx)
// After the first relay closes, the reconnection should deliver
// the second relay part.
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "second-relay" {
return true
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
require.GreaterOrEqual(t, int(callCount.Load()), 2)
}
func TestSubscribeRelayAsyncDoesNotBlock(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerID := uuid.New()
subscriberID := uuid.New()
dialStarted := make(chan struct{})
dialContinue := make(chan struct{})
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
// Signal that the dial has started, then block until released.
select {
case <-dialStarted:
default:
close(dialStarted)
}
select {
case <-dialContinue:
case <-ctx.Done():
return nil, nil, nil, ctx.Err()
}
ch := make(chan codersdk.ChatStreamEvent, 10)
return nil, ch, func() {}, nil
}
subscriber := newTestServer(t, db, ps, subscriberID, provider, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Seed a waiting chat so Subscribe does not trigger a synchronous
// relay.
chat := seedWaitingChat(ctx, t, db, user, model, "relay-async-nonblock")
// Subscribe before the chat is marked running so the relay opens
// via pubsub notification (openRelayAsync path).
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Now mark the chat as running on a remote worker. This publishes
// a status notification which triggers openRelayAsync on the
// subscriber.
notify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: workerID.String(),
}
payload, err := json.Marshal(notify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload)
require.NoError(t, err)
// Wait for the relay dial to actually start (blocking in the
// provider).
select {
case <-dialStarted:
case <-ctx.Done():
t.Fatal("timed out waiting for relay dial to start")
}
// While the relay is still dialing (provider is blocked), publish
// another status change. If openRelayAsync blocked the select loop
// this event would never arrive.
statusNotify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusWaiting),
}
statusPayload, err := json.Marshal(statusNotify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), statusPayload)
require.NoError(t, err)
// The waiting status event should arrive promptly despite the
// relay still dialing.
require.Eventually(t, func() bool {
select {
case event := <-events:
return event.Type == codersdk.ChatStreamEventTypeStatus &&
event.Status != nil &&
event.Status.Status == codersdk.ChatStatusWaiting
default:
return false
}
}, testutil.WaitShort, testutil.IntervalFast)
// Unblock the relay dial so the test can clean up.
close(dialContinue)
}
func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerID := uuid.New()
subscriberID := uuid.New()
provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
// Return a non-empty snapshot with two parts.
snapshot := []codersdk.ChatStreamEvent{
{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("snap-one"),
},
},
{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("snap-two"),
},
},
}
ch := make(chan codersdk.ChatStreamEvent, 10)
// Also send a live part after the snapshot.
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("live-part"),
},
}
return snapshot, ch, func() {}, nil
}
subscriber := newTestServer(t, db, ps, subscriberID, provider, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat := seedRemoteRunningChat(ctx, t, db, user, model, workerID, "relay-snapshot")
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// The relay snapshot parts are forwarded through the events
// channel by the enterprise SubscribeFn. Collect them along
// with the live part.
var receivedTexts []string
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil {
receivedTexts = append(receivedTexts, event.MessagePart.Part.Text)
}
// We expect snap-one, snap-two, and live-part.
return len(receivedTexts) >= 3
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts)
// The initial snapshot should still contain the status event
// from the OSS preamble.
var hasStatus bool
for _, event := range initialSnapshot {
if event.Type == codersdk.ChatStreamEventTypeStatus {
hasStatus = true
}
}
require.True(t, hasStatus, "initial snapshot should contain status event")
}
func TestSubscribeRetryEventAcrossInstances(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerID := uuid.New()
subscriberID := uuid.New()
var streamCalls atomic.Int32
firstStreamStarted := make(chan struct{})
allowFirstFailure := make(chan struct{})
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("retry-across-instances")
}
if streamCalls.Add(1) == 1 {
select {
case <-firstStreamStarted:
default:
close(firstStreamStarted)
}
<-allowFirstFailure
return chattest.OpenAIRateLimitResponse()
}
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("retry", " complete")...)
})
worker := newActiveWorkerServer(t, db, ps, workerID)
subscriber := newTestServer(t, db, ps, subscriberID, func(
ctx context.Context,
chatID uuid.UUID,
targetWorkerID uuid.UUID,
requestHeader http.Header,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
error,
) {
if targetWorkerID != workerID {
return nil, nil, nil, xerrors.Errorf("unexpected relay target %s", targetWorkerID)
}
snapshot, events, cancel, ok := worker.Subscribe(ctx, chatID, requestHeader, math.MaxInt64)
if !ok {
return nil, nil, nil, xerrors.New("worker subscribe failed")
}
return snapshot, events, cancel, nil
}, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
chat, err := worker.CreateChat(ctx, osschatd.CreateOptions{
OwnerID: user.ID,
Title: "retry-across-instances",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
require.Eventually(t, func() bool {
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusRunning &&
fromDB.WorkerID.Valid && fromDB.WorkerID.UUID == workerID
}, testutil.WaitMedium, testutil.IntervalFast)
select {
case <-firstStreamStarted:
case <-ctx.Done():
t.Fatal("timed out waiting for first streaming attempt")
}
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
defer cancel()
close(allowFirstFailure)
var retryEvent *codersdk.ChatStreamRetry
var waitingSeen bool
var waitingBeforeRetry bool
var assistantMessageBeforeRetry bool
require.Eventually(t, func() bool {
select {
case event, ok := <-events:
if !ok {
return false
}
switch event.Type {
case codersdk.ChatStreamEventTypeRetry:
retryEvent = event.Retry
case codersdk.ChatStreamEventTypeMessage:
if event.Message != nil && event.Message.Role == codersdk.ChatMessageRoleAssistant {
if retryEvent == nil {
assistantMessageBeforeRetry = true
}
}
case codersdk.ChatStreamEventTypeStatus:
if event.Status != nil && event.Status.Status == codersdk.ChatStatusWaiting {
if retryEvent == nil {
waitingBeforeRetry = true
}
waitingSeen = true
}
}
return retryEvent != nil && waitingSeen
default:
return false
}
}, testutil.WaitLong, testutil.IntervalFast)
require.NotNil(t, retryEvent)
require.Equal(t, 1, retryEvent.Attempt)
require.Greater(t, retryEvent.DelayMs, int64(0))
require.Equal(t, "rate_limit", retryEvent.Kind)
require.Equal(t, "openai", retryEvent.Provider)
require.Equal(t, 429, retryEvent.StatusCode)
require.Contains(t, retryEvent.Error, "rate limiting requests")
require.False(t, assistantMessageBeforeRetry)
require.False(t, waitingBeforeRetry)
require.GreaterOrEqual(t, streamCalls.Load(), int32(2))
}
// TestSubscribeRelayStaleDialDiscardedAfterInterrupt verifies that when a
// user interrupts a streaming chat and sends a new message (which gets
// picked up by a different replica), an in-flight relay dial to the
// OLD replica is canceled/discarded and the relay connects to the
// NEW replica correctly.
func TestSubscribeRelayStaleDialDiscardedAfterInterrupt(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
oldWorkerID := uuid.New()
newWorkerID := uuid.New()
subscriberID := uuid.New()
// Gate to hold the first dial until we're ready.
firstDialStarted := make(chan struct{})
releaseFirstDial := make(chan struct{})
var callCount atomic.Int32
provider := func(ctx context.Context, _ uuid.UUID, workerID uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
call := callCount.Add(1)
ch := make(chan codersdk.ChatStreamEvent, 10)
if call == 1 {
// First dial (to old worker): signal that we started,
// then block until released or context canceled.
close(firstDialStarted)
select {
case <-releaseFirstDial:
case <-ctx.Done():
return nil, nil, nil, ctx.Err()
}
// If we get here after being released (not canceled),
// return a stale part — this should be discarded.
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("stale-part"),
},
}
close(ch)
return nil, ch, func() {}, nil
}
// Second dial (to new worker): return a valid part.
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("new-worker-part"),
},
}
return nil, ch, func() {}, nil
}
subscriber := newTestServer(t, db, ps, subscriberID, provider, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Seed the chat in waiting state so Subscribe does not try an initial
// relay.
chat := seedWaitingChat(ctx, t, db, user, model, "stale-dial-test")
// Subscribe while chat is in "waiting" state — no relay opened.
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Now simulate the chat being picked up by the OLD worker via pubsub.
// This triggers openRelayAsync in the merge loop.
_, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: oldWorkerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
oldRunningNotify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: oldWorkerID.String(),
}
oldRunningPayload, err := json.Marshal(oldRunningNotify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), oldRunningPayload)
require.NoError(t, err)
// Wait for the first dial goroutine to start (it's blocked in the provider).
select {
case <-firstDialStarted:
case <-ctx.Done():
t.Fatal("timed out waiting for first dial to start")
}
// Simulate interrupt: chat goes to "waiting".
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusWaiting,
})
require.NoError(t, err)
waitingNotify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusWaiting),
}
waitingPayload, err := json.Marshal(waitingNotify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), waitingPayload)
require.NoError(t, err)
// Wait for the merge loop to process the waiting notification
// and emit the status event before publishing the new running
// notification. This avoids time.Sleep (banned by project
// policy) and provides a deterministic sync point.
require.Eventually(t, func() bool {
select {
case event := <-events:
return event.Type == codersdk.ChatStreamEventTypeStatus &&
event.Status != nil &&
event.Status.Status == codersdk.ChatStatusWaiting
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
// Now the chat transitions to running on the NEW worker.
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: newWorkerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
runningNotify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: newWorkerID.String(),
}
runningPayload, err := json.Marshal(runningNotify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), runningPayload)
require.NoError(t, err)
// Now release the first dial (if it wasn't already canceled).
close(releaseFirstDial)
// The subscriber should receive parts from the NEW worker, not the stale one.
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "new-worker-part" {
return true
}
// If we get the stale part, the bug is present.
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "stale-part" {
t.Fatal("received stale part from old worker — relay did not cancel in-flight dial")
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
// Drain the events channel for a while to ensure no late-arriving
// stale part sneaks in after the require.Eventually above returned.
// This closes the timing gap where "stale-part" could arrive after
// "new-worker-part" was already consumed.
require.Never(t, func() bool {
select {
case event := <-events:
return event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "stale-part"
default:
return false
}
}, 2*time.Second, testutil.IntervalFast)
}
// TestSubscribeCancelDuringInFlightDial verifies that calling the
// subscription's cancel function while a relay dial goroutine is
// still blocking in the provider causes the provider's context to
// be canceled and the goroutine to return cleanly.
func TestSubscribeCancelDuringInFlightDial(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerID := uuid.New()
subscriberID := uuid.New()
dialStarted := make(chan struct{})
dialExited := make(chan struct{})
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
// Signal the dial has started, then block until the context
// is canceled.
close(dialStarted)
<-ctx.Done()
close(dialExited)
return nil, nil, nil, ctx.Err()
}
subscriber := newTestServer(t, db, ps, subscriberID, provider, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Seed the chat in waiting state so Subscribe does not open a
// synchronous relay.
chat := seedWaitingChat(ctx, t, db, user, model, "cancel-inflight-dial")
_, _, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
// Publish a running notification to trigger openRelayAsync.
notify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: workerID.String(),
}
payload, err := json.Marshal(notify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload)
require.NoError(t, err)
// Wait for the dial goroutine to block inside the provider.
select {
case <-dialStarted:
case <-ctx.Done():
t.Fatal("timed out waiting for dial to start")
}
// Cancel the subscription while the dial is still in-flight.
cancel()
// The provider context must be canceled, causing the goroutine
// to return cleanly.
require.Eventually(t, func() bool {
select {
case <-dialExited:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
}
// TestSubscribeRelayRunningToRunningSwitch verifies that when a chat
// transitions directly from running(workerA) to running(workerB)
// without an intermediate waiting state, the relay switches to the
// new worker and discards parts from the old one.
func TestSubscribeRelayRunningToRunningSwitch(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerA := uuid.New()
workerB := uuid.New()
subscriberID := uuid.New()
// Gate to hold workerA's dial until we verify cancellation.
dialAStarted := make(chan struct{})
dialAExited := make(chan struct{})
var callCount atomic.Int32
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
call := callCount.Add(1)
if call == 1 {
// First dial (to workerA): signal that we started,
// then block until the context is canceled.
close(dialAStarted)
<-ctx.Done()
close(dialAExited)
return nil, nil, nil, ctx.Err()
}
// Second dial (to workerB): return a valid part.
ch := make(chan codersdk.ChatStreamEvent, 10)
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("worker-b-part"),
},
}
return nil, ch, func() {}, nil
}
subscriber := newTestServer(t, db, ps, subscriberID, provider, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Seed the chat in waiting state so Subscribe does not open a relay.
chat := seedWaitingChat(ctx, t, db, user, model, "running-to-running")
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Transition to running on workerA.
notifyA := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: workerA.String(),
}
payloadA, err := json.Marshal(notifyA)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payloadA)
require.NoError(t, err)
// Wait for the workerA dial goroutine to block inside the
// provider before publishing the workerB notification.
select {
case <-dialAStarted:
case <-ctx.Done():
t.Fatal("timed out waiting for workerA dial to start")
}
// Immediately transition to running on workerB (no waiting in
// between). This should cancel workerA's in-flight dial.
notifyB := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: workerB.String(),
}
payloadB, err := json.Marshal(notifyB)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payloadB)
require.NoError(t, err)
// Verify that the relay canceled workerA's stale dial.
require.Eventually(t, func() bool {
select {
case <-dialAExited:
return true
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
// We should receive the part from workerB.
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "worker-b-part" {
return true
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
require.Equal(t, 2, int(callCount.Load()))
}
// TestSubscribeRelayFailedDialRetries verifies that when an async relay
// dial fails (returns an error), the merge loop schedules a reconnect
// timer and eventually re-dials successfully. This exercises the
// result.parts == nil path and the scheduleRelayReconnect() logic.
func TestSubscribeRelayFailedDialRetries(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
remoteWorkerID := uuid.New()
subscriberID := uuid.New()
var callCount atomic.Int32
provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
call := callCount.Add(1)
if call == 1 {
// First dial: fail with an error to trigger
// scheduleRelayReconnect via the result.parts == nil path.
return nil, nil, nil, xerrors.New("transient dial failure")
}
// Second dial: succeed and return a part.
ch := make(chan codersdk.ChatStreamEvent, 10)
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("retry-success"),
},
}
return nil, ch, func() {}, nil
}
mclk := quartz.NewMock(t)
// Trap the reconnect timer so we can fire it deterministically.
trapReconnect := mclk.Trap().NewTimer("reconnect")
defer trapReconnect.Close()
subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Seed the chat in waiting state so Subscribe does not open a
// synchronous relay dial.
chat := seedWaitingChat(ctx, t, db, user, model, "failed-dial-retry")
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Now mark the chat as running on the remote worker in the DB.
// The reconnect timer calls params.DB.GetChatByID to check if
// the chat is still running on a remote worker, so this must be
// set before we advance the clock.
_, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: remoteWorkerID, Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
})
require.NoError(t, err)
// Publish a running notification with a remote workerID to
// trigger openRelayAsync. The first dial will fail, causing
// scheduleRelayReconnect to be called.
notify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: remoteWorkerID.String(),
}
payload, err := json.Marshal(notify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload)
require.NoError(t, err)
// Wait for the reconnect timer to be created (after the failed
// dial), then advance the mock clock to fire it.
trapReconnect.MustWait(ctx).MustRelease(ctx)
mclk.Advance(500 * time.Millisecond).MustWait(ctx)
// The merge loop re-checks the DB, sees the chat is still
// running on the remote worker, and dials again. The second
// dial succeeds.
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "retry-success" {
return true
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
require.GreaterOrEqual(t, int(callCount.Load()), 2)
}
// TestSubscribeRunningLocalWorkerClosesRelay verifies that when a chat
// is running on a remote worker and a pubsub notification arrives
// saying the local worker (subscriberID) now owns the chat, the
// existing relay is closed and no new dial is started (the local
// worker serves directly without relaying).
func TestSubscribeRunningLocalWorkerClosesRelay(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
remoteWorkerID := uuid.New()
subscriberID := uuid.New()
var callCount atomic.Int32
provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
call := callCount.Add(1)
ch := make(chan codersdk.ChatStreamEvent, 10)
if call == 1 {
// Initial synchronous dial to the remote worker.
ch <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessageText("remote-part"),
},
}
// Keep channel open so the relay stays active.
}
return nil, ch, func() {}, nil
}
subscriber := newTestServer(t, db, ps, subscriberID, provider, nil)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat := seedRemoteRunningChat(
ctx,
t,
db,
user,
model,
remoteWorkerID,
"local-worker-closes-relay",
)
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Consume the remote-part from the initial relay.
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == "remote-part" {
return true
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
// Notify that the LOCAL worker now owns the chat. This should
// close the relay without opening a new one.
notify := coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusRunning),
WorkerID: subscriberID.String(),
}
payload, err := json.Marshal(notify)
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload)
require.NoError(t, err)
// Give the system time to process the notification. No additional
// dial should happen — only the initial synchronous one.
require.Never(t, func() bool {
return int(callCount.Load()) > 1
}, 2*time.Second, testutil.IntervalFast)
require.Equal(t, 1, int(callCount.Load()),
"only the initial synchronous dial should have happened")
}
// TestSubscribeRelayMultipleReconnects verifies that the reconnect
// loop handles multiple consecutive relay drops, proving it is
// robust across repeated iterations — not just the single reconnect
// already covered by TestSubscribeRelayReconnectsOnDrop.
func TestSubscribeRelayMultipleReconnects(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
workerID := uuid.New()
subscriberID := uuid.New()
var callCount atomic.Int32
provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
) {
call := callCount.Add(1)
ch := make(chan codersdk.ChatStreamEvent, 10)
part := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: "assistant",
Part: codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: fmt.Sprintf("relay-%d", call),
},
},
}
ch <- part
if call <= 2 {
// First two dials: close channel to simulate relay
// drop. This triggers scheduleRelayReconnect.
close(ch)
}
// Third dial: keep channel open.
return nil, ch, func() {}, nil
}
mclk := quartz.NewMock(t)
// Trap the reconnect timer so we can fire both reconnects
// deterministically.
trapReconnect := mclk.Trap().NewTimer("reconnect")
defer trapReconnect.Close()
subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
chat := seedRemoteRunningChat(
ctx,
t,
db,
user,
model,
workerID,
"multiple-reconnects",
)
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
require.True(t, ok)
t.Cleanup(cancel)
// Helper to consume a specific relay part.
consumePart := func(text string) {
t.Helper()
require.Eventually(t, func() bool {
select {
case event := <-events:
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
event.MessagePart != nil &&
event.MessagePart.Part.Text == text {
return true
}
return false
default:
return false
}
}, testutil.WaitMedium, testutil.IntervalFast)
}
// First relay: consumed immediately (synchronous dial).
consumePart("relay-1")
// First relay drops → reconnect timer created. Advance clock
// to fire it.
trapReconnect.MustWait(ctx).MustRelease(ctx)
mclk.Advance(500 * time.Millisecond).MustWait(ctx)
// Second relay part.
consumePart("relay-2")
// Second relay drops → another reconnect timer. Advance again.
trapReconnect.MustWait(ctx).MustRelease(ctx)
mclk.Advance(500 * time.Millisecond).MustWait(ctx)
// Third relay part (channel stays open).
consumePart("relay-3")
require.GreaterOrEqual(t, int(callCount.Load()), 3)
}