Files
Cian Johnston 2f855904be refactor: add dbgen chat generators and migrate test boilerplate (#24497)
- Adds chat-related dbgen generators covering defaults, overrides, and message field mapping.
- Replaces raw single-row chat, message, provider, and model-config setup in tests with dbgen helpers.
- Simplifies chat seed helpers after moving fixture setup into dbgen.

> Generated with [Coder Agents](https://coder.com/agents).
2026-05-01 13:29:33 +01:00

797 lines
26 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package chatd_test
import (
"context"
"database/sql"
"encoding/json"
"io"
"math"
"net/http"
"net/http/httptest"
"regexp"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
osschatd "github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/codersdk"
entchatd "github.com/coder/coder/v2/enterprise/coderd/x/chatd"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// mulPhi scales d by the golden ratio (math.Phi), producing the next
// delay in retry.Retrier's φ-growth backoff sequence. The fractional
// nanoseconds are truncated by the Duration conversion. If
// TestRelayReconnectUsesExponentialBackoff starts failing after a
// retry library bump, check whether the growth factor has changed.
func mulPhi(d time.Duration) time.Duration {
	scaled := math.Phi * float64(d)
	return time.Duration(scaled)
}
// setChatRunningAndPublish flips the chat row to running on workerID
// (stamping StartedAt and HeartbeatAt with the same instant) and then
// publishes the matching status notification on the chat's stream
// channel. Keeping the DB row and the pubsub message in lockstep lets
// the async reconnect loop re-dial on every timer fire, because the
// reconnect branch re-reads the DB status before calling
// openRelayAsync.
func setChatRunningAndPublish(
	ctx context.Context,
	t *testing.T,
	db database.Store,
	ps dbpubsub.Pubsub,
	chatID, workerID uuid.UUID,
) {
	t.Helper()

	stamp := sql.NullTime{Time: time.Now(), Valid: true}
	params := database.UpdateChatStatusParams{
		ID:          chatID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: workerID, Valid: true},
		StartedAt:   stamp,
		HeartbeatAt: stamp,
	}
	_, err := db.UpdateChatStatus(ctx, params)
	require.NoError(t, err)

	notification := coderdpubsub.ChatStreamNotifyMessage{
		Status:   string(database.ChatStatusRunning),
		WorkerID: workerID.String(),
	}
	payload, err := json.Marshal(notification)
	require.NoError(t, err)

	channel := coderdpubsub.ChatStreamNotifyChannel(chatID)
	require.NoError(t, ps.Publish(channel, payload))
}
// TestRelayDialErrorIsUnrecoverable pins the status-code classification
// policy of RelayDialError.IsUnrecoverable: only 401 and 403 are
// terminal, everything else (including a pre-response failure with
// status 0) is retryable. Adding a new HTTP status to the
// unrecoverable set should force an edit to this table too.
func TestRelayDialErrorIsUnrecoverable(t *testing.T) {
	t.Parallel()

	for _, c := range []struct {
		name   string
		status int
		want   bool
	}{
		{name: "unauthorized", status: http.StatusUnauthorized, want: true},
		{name: "forbidden", status: http.StatusForbidden, want: true},
		{name: "internal_server", status: http.StatusInternalServerError, want: false},
		{name: "bad_gateway", status: http.StatusBadGateway, want: false},
		{name: "service_unavailable", status: http.StatusServiceUnavailable, want: false},
		{name: "too_many_requests", status: http.StatusTooManyRequests, want: false},
		{name: "pre_response", status: 0, want: false},
		{name: "bad_request", status: http.StatusBadRequest, want: false},
		{name: "not_found", status: http.StatusNotFound, want: false},
	} {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			dialErr := &entchatd.RelayDialError{HTTPStatus: c.status, Err: io.EOF}
			got := dialErr.IsUnrecoverable()
			require.Equal(t, c.want, got, "status=%d", c.status)
		})
	}
}
// TestRelayReconnectUsesExponentialBackoff asserts that the reconnect
// timer follows the φ-growth sequence produced by
// github.com/coder/retry's defaults, floored at relayRetryFloor.
func TestRelayReconnectUsesExponentialBackoff(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	workerID := uuid.New()
	subscriberID := uuid.New()

	// Every dial fails with a 502, which the classifier treats as
	// recoverable, so the merge loop keeps scheduling reconnect timers.
	var failCount atomic.Int32
	dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
		[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
	) {
		failCount.Add(1)
		return nil, nil, nil, &entchatd.RelayDialError{
			HTTPStatus: http.StatusBadGateway,
			Err:        io.EOF,
		}
	}

	mclk := quartz.NewMock(t)
	trapReconnect := mclk.Trap().NewTimer("reconnect")
	defer trapReconnect.Close()
	subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk)

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	chat := seedWaitingChat(t, db, org.ID, user, model, "relay-backoff")

	_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	// Kick the async relay loop and keep the DB row in sync so
	// each reconnect timer fire triggers another dial.
	setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID)

	// Expected sequence from retry.Retrier math:
	//   attempt 1 → floor (500ms)
	//   attempt n → prev × φ (capped at ceil)
	floor := 500 * time.Millisecond
	expected := []time.Duration{
		floor,
		mulPhi(floor),
		mulPhi(mulPhi(floor)),
		mulPhi(mulPhi(mulPhi(floor))),
		mulPhi(mulPhi(mulPhi(mulPhi(floor)))),
	}
	for i, want := range expected {
		call := trapReconnect.MustWait(ctx)
		require.Equal(t, want, call.Duration,
			"attempt %d: want %v got %v", i+1, want, call.Duration)
		call.MustRelease(ctx)
		mclk.Advance(want).MustWait(ctx)
	}

	// We expect 1 initial attempt + 5 reconnects fired by the
	// trapped timer = 6 dials before the cap-check runs. Use
	// Eventually so we don't race the final dial goroutine that
	// the last Advance kicked off.
	//
	// The message is deliberately static: require.Eventually's
	// msgAndArgs are evaluated once, when Eventually is called, so an
	// embedded failCount.Load() would report the pre-wait count rather
	// than the count at timeout.
	require.Eventually(t, func() bool {
		return failCount.Load() >= 6
	}, testutil.WaitShort, testutil.IntervalFast,
		"expected at least 6 dials")

	// The events channel must remain open - we're still under the
	// cap. This is a non-blocking peek: an empty channel passes.
	select {
	case ev, open := <-events:
		if !open {
			t.Fatalf("events channel closed prematurely; retries should continue below cap")
		}
		// Allow through events that might have been queued; just
		// confirm it's not a terminal error.
		if ev.Type == codersdk.ChatStreamEventTypeError {
			t.Fatalf("unexpected terminal error: %v", ev.Error)
		}
	default:
	}
}
// TestRelayRepeatedDropsHitCap verifies the cap covers a peer that
// accepts the handshake and immediately drops it. Without a proper
// cap, such a peer would produce one reconnect per floor delay
// forever. The retry counter must accumulate across dial-success /
// parts-close cycles so the cap trips.
func TestRelayRepeatedDropsHitCap(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	workerID := uuid.New()
	subscriberID := uuid.New()

	// Every dial "succeeds" and hands its parts channel to the test
	// through opened, so the test decides when each connection drops.
	opened := make(chan chan codersdk.ChatStreamEvent, 32)
	// dialCount is named distinctly from the per-iteration timer
	// call below so the two never shadow each other.
	var dialCount atomic.Int32
	dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
		[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
	) {
		dialCount.Add(1)
		ch := make(chan codersdk.ChatStreamEvent, 1)
		opened <- ch
		return nil, ch, func() {}, nil
	}

	mclk := quartz.NewMock(t)
	trapReconnect := mclk.Trap().NewTimer("reconnect")
	defer trapReconnect.Close()
	subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk)

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	chat := seedWaitingChat(t, db, org.ID, user, model, "relay-drops")

	_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	// Kick off the first async dial.
	setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID)

	// Close the first dial's parts channel so the merge loop
	// schedules a reconnect. Then advance 6 reconnect timers,
	// closing the parts channel each time so the cycle is:
	//   dial -> success -> parts-close -> next() -> reconnect.
	// 1 initial dial + 6 timer-driven dials = 7 total; the 7th
	// parts-close trips the cap.
	for i := 0; i < 7; i++ {
		var ch chan codersdk.ChatStreamEvent
		select {
		case ch = <-opened:
		case <-ctx.Done():
			t.Fatalf("timed out waiting for dial %d", i+1)
		}
		// Closing the parts channel triggers the relayPartsCh
		// close branch, which calls retryState.next() and
		// schedules the next reconnect.
		close(ch)
		if i == 6 {
			// 7th parts-close should trip the cap; no more
			// reconnect timers.
			break
		}
		timer := trapReconnect.MustWait(ctx)
		timer.MustRelease(ctx)
		mclk.Advance(timer.Duration).MustWait(ctx)
	}

	// A terminal error event must arrive on the events channel.
	var errEvent *codersdk.ChatStreamEvent
	require.Eventually(t, func() bool {
		select {
		case ev, open := <-events:
			if !open {
				return errEvent != nil
			}
			if ev.Type == codersdk.ChatStreamEventTypeError {
				errEvent = &ev
				return true
			}
			return false
		default:
			return false
		}
	}, testutil.WaitShort, testutil.IntervalFast,
		"expected a terminal error event after repeated drops hit cap")
	require.NotNil(t, errEvent, "expected a terminal error event after repeated drops hit cap")
	require.NotNil(t, errEvent.Error)
	require.Contains(t, errEvent.Error.Message, "relay connection failed")

	// We should have observed exactly 7 dials before tear-down.
	require.Equal(t, int32(7), dialCount.Load(),
		"expected 7 dials (1 initial + 6 reconnect retries) before cap")
}
// TestRelayStopsAfterIntermittentCap covers the cap-reached tear-down
// path. The dialer always fails recoverably (502). After the initial
// dial plus six reconnect-timer dials, the merge loop must emit a
// single terminal error event mentioning the cap and stop dialing.
func TestRelayStopsAfterIntermittentCap(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	workerID, subscriberID := uuid.New(), uuid.New()

	var dials atomic.Int32
	failingDialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
		[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
	) {
		dials.Add(1)
		dialErr := &entchatd.RelayDialError{HTTPStatus: http.StatusBadGateway, Err: io.EOF}
		return nil, nil, nil, dialErr
	}

	clock := quartz.NewMock(t)
	reconnectTrap := clock.Trap().NewTimer("reconnect")
	defer reconnectTrap.Close()
	subscriber := newTestServer(t, db, ps, subscriberID, failingDialer, clock)

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	chat := seedWaitingChat(t, db, org.ID, user, model, "relay-cap")

	_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID)

	// Fire six consecutive reconnect timers. Each fire re-dials, the
	// dial fails, and the next timer is scheduled; once the retry
	// budget is exhausted the following next() reports giveUp and the
	// merge loop tears down.
	for fired := 0; fired < 6; fired++ {
		timer := reconnectTrap.MustWait(ctx)
		timer.MustRelease(ctx)
		clock.Advance(timer.Duration).MustWait(ctx)
	}

	// Poll for the terminal error event rather than for channel
	// closure: the enterprise merge goroutine closes its own
	// mergedEvents, but OSS only nils out relayEvents, so the outer
	// events channel stays open for pubsub/local traffic.
	var errEvent *codersdk.ChatStreamEvent
	require.Eventually(t, func() bool {
		select {
		case ev, open := <-events:
			switch {
			case !open:
				return errEvent != nil
			case ev.Type == codersdk.ChatStreamEventTypeError:
				errEvent = &ev
				return true
			default:
				return false
			}
		default:
			return false
		}
	}, testutil.WaitShort, testutil.IntervalFast,
		"expected a terminal error event")
	require.NotNil(t, errEvent, "expected a terminal error event")
	require.NotNil(t, errEvent.Error)
	msg := errEvent.Error.Message
	require.Contains(t, msg, "relay connection failed")
	require.Contains(t, msg, "6")

	// The retry state permits relayMaxRetries successful next() calls
	// before flipping giveUp: 1 initial dial + 6 timer-driven dials,
	// and the 7th next() tears down. Nothing should dial afterwards.
	totalDials := dials.Load()
	require.Equal(t, int32(7), totalDials,
		"expected exactly relayMaxRetries+1 dials before cap; got %d", totalDials)
}
// chatByIDErrorStore wraps a database.Store so that the first
// okRemain GetChatByID calls delegate to the wrapped store and every
// call after that returns err. The failure is sticky: it is not a
// one-shot error. This lets the initial Subscribe call succeed (OSS's
// initial state load needs a real Chat to wire up the relay) while
// subsequent reconnect-branch calls exercise the DB-error retry path.
// Every other Store method is promoted unchanged from the embedded
// store.
type chatByIDErrorStore struct {
	database.Store
	// err is returned by GetChatByID once the okRemain budget is spent.
	err error
	// okRemain is the number of GetChatByID calls allowed to delegate
	// before erroring. NOTE(review): atomic presumably because the
	// server may call GetChatByID from more than one goroutine.
	okRemain atomic.Int32
}
// GetChatByID consumes one unit of the okRemain budget per call. While
// the decremented budget is still non-negative it forwards to the
// wrapped store; once it drops below zero it returns a zero Chat and
// the configured error on every call.
func (s *chatByIDErrorStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) {
	if remaining := s.okRemain.Add(-1); remaining < 0 {
		return database.Chat{}, s.err
	}
	return s.Store.GetChatByID(ctx, id)
}
// TestRelayReconnectStopsAfterDBErrorCap checks that DB errors hit in
// the reconnect-timer branch draw from the same retry budget as dial
// failures. After enough consecutive GetChatByID failures the cap
// trips without the dialer being invoked again.
func TestRelayReconnectStopsAfterDBErrorCap(t *testing.T) {
	t.Parallel()

	realDB, ps := dbtestutil.NewDB(t)
	workerID := uuid.New()
	subscriberID := uuid.New()

	var dials atomic.Int32
	dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
		[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
	) {
		dials.Add(1)
		dialErr := &entchatd.RelayDialError{HTTPStatus: http.StatusBadGateway, Err: io.EOF}
		return nil, nil, nil, dialErr
	}

	mclk := quartz.NewMock(t)
	trapReconnect := mclk.Trap().NewTimer("reconnect")
	defer trapReconnect.Close()

	// The server gets a store whose GetChatByID succeeds exactly once
	// (the Subscribe preamble's initial state load) and fails on every
	// later call; all other methods go straight to realDB.
	failingDB := &chatByIDErrorStore{
		Store: realDB,
		err:   xerrors.New("mock: GetChatByID always fails"),
	}
	failingDB.okRemain.Store(1)

	ctx := testutil.Context(t, testutil.WaitLong)
	// Seed through realDB so seeding never spends the wrapper's budget.
	user, org, model := seedChatDependencies(t, realDB)
	chat := seedWaitingChat(t, realDB, org.ID, user, model, "relay-db-error")

	subscriber := newTestServer(t, failingDB, ps, subscriberID, dialer, mclk)
	_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	// Going to running starts an async dial, which fails (attempts=1,
	// reconnect scheduled). Every later timer fire makes the merge
	// loop call GetChatByID, which errors and bumps retryState.next().
	// Budget: 1 dial failure + 6 DB failures = 7 next() calls, and the
	// 7th trips the cap.
	setChatRunningAndPublish(ctx, t, realDB, ps, chat.ID, workerID)
	for fired := 0; fired < 6; fired++ {
		timer := trapReconnect.MustWait(ctx)
		timer.MustRelease(ctx)
		mclk.Advance(timer.Duration).MustWait(ctx)
	}

	var errEvent *codersdk.ChatStreamEvent
	require.Eventually(t, func() bool {
		select {
		case ev, open := <-events:
			switch {
			case !open:
				return errEvent != nil
			case ev.Type == codersdk.ChatStreamEventTypeError:
				errEvent = &ev
				return true
			default:
				return false
			}
		default:
			return false
		}
	}, testutil.WaitShort, testutil.IntervalFast,
		"expected terminal error event after DB-error cap")
	require.NotNil(t, errEvent.Error)
	msg := errEvent.Error.Message
	require.Contains(t, msg, "relay connection failed")
	require.Contains(t, msg, "6")

	// Only the very first dial should have happened; every later
	// next() came from the DB-error branch, which short-circuits
	// before reaching the dialer.
	require.Equal(t, int32(1), dials.Load(),
		"expected exactly 1 dial; reconnects should short-circuit on DB error")
}
// TestRelayStopsImmediatelyOnUnauthorized exercises the merge loop's
// unrecoverable-error branch across a table of dial statuses:
//   - 401 and 403 must tear the relay down after the first dial with
//     a "relay authentication failed" error event.
//   - 500 and a pre-response failure (status 0) must instead schedule
//     a reconnect and dial again.
func TestRelayStopsImmediatelyOnUnauthorized(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name              string
		status            int
		wantUnrecoverable bool
		wantMsgContains   string
	}{
		{"401", http.StatusUnauthorized, true, "401"},
		{"403", http.StatusForbidden, true, "403"},
		{"500_intermittent", http.StatusInternalServerError, false, ""},
		{"zero_intermittent", 0, false, ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			db, ps := dbtestutil.NewDB(t)
			workerID := uuid.New()
			subscriberID := uuid.New()

			var callCount atomic.Int32
			dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
				[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
			) {
				callCount.Add(1)
				return nil, nil, nil, &entchatd.RelayDialError{
					HTTPStatus: tc.status,
					Err:        io.EOF,
				}
			}

			mclk := quartz.NewMock(t)
			trapReconnect := mclk.Trap().NewTimer("reconnect")
			defer trapReconnect.Close()
			subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk)

			ctx := testutil.Context(t, testutil.WaitLong)
			user, org, model := seedChatDependencies(t, db)
			chat := seedWaitingChat(t, db, org.ID, user, model,
				"relay-unrec-"+tc.name)

			_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
			require.True(t, ok)
			t.Cleanup(cancel)

			setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID)

			if !tc.wantUnrecoverable {
				// Intermittent: fire one reconnect timer
				// and confirm the dialer is called again.
				call := trapReconnect.MustWait(ctx)
				call.MustRelease(ctx)
				mclk.Advance(call.Duration).MustWait(ctx)
				require.Eventually(t, func() bool {
					return callCount.Load() >= 2
				}, testutil.WaitShort, testutil.IntervalFast,
					"intermittent should retry at least once")
				return
			}

			// Unrecoverable: the first dial should tear the relay down.
			var errEvent *codersdk.ChatStreamEvent
			require.Eventually(t, func() bool {
				select {
				case ev, open := <-events:
					if !open {
						return errEvent != nil
					}
					if ev.Type == codersdk.ChatStreamEventTypeError {
						errEvent = &ev
						return true
					}
					return false
				default:
					return false
				}
			}, testutil.WaitShort, testutil.IntervalFast,
				"expected terminal error event")
			require.NotNil(t, errEvent)
			// Guard the dereference below so a nil Error fails the
			// test cleanly instead of panicking.
			require.NotNil(t, errEvent.Error)
			require.Contains(t, errEvent.Error.Message, "relay authentication failed")
			require.Contains(t, errEvent.Error.Message, tc.wantMsgContains)
			// Load once so the asserted value and the reported value
			// cannot diverge.
			dials := callCount.Load()
			require.Equal(t, int32(1), dials,
				"unrecoverable errors must not retry; got %d dials", dials)
		})
	}
}
// TestRelayBackoffResetsOnStatusChange checks that closeRelay, when
// driven by a status notification, resets the retry counter. A relay
// opened afterwards against a different worker must therefore start
// again at the floor delay rather than the grown one.
func TestRelayBackoffResetsOnStatusChange(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	workerID1 := uuid.New()
	workerID2 := uuid.New()
	subscriberID := uuid.New()

	alwaysFail := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
		[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
	) {
		dialErr := &entchatd.RelayDialError{HTTPStatus: http.StatusBadGateway, Err: io.EOF}
		return nil, nil, nil, dialErr
	}

	mclk := quartz.NewMock(t)
	trapReconnect := mclk.Trap().NewTimer("reconnect")
	defer trapReconnect.Close()
	subscriber := newTestServer(t, db, ps, subscriberID, alwaysFail, mclk)

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	chat := seedWaitingChat(t, db, org.ID, user, model, "relay-reset-on-status")

	_, _, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	// floor is the first delay of the reconnect backoff sequence.
	floor := 500 * time.Millisecond

	// Drive the async openRelayAsync path against workerID1 and let
	// three intermittent failures grow the delay past the floor.
	setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID1)
	for n := 0; n < 3; n++ {
		timer := trapReconnect.MustWait(ctx)
		timer.MustRelease(ctx)
		mclk.Advance(timer.Duration).MustWait(ctx)
	}

	// Consume the timer armed after the 3rd failure without advancing
	// the clock. It is the grown one, and the reset should replace it
	// with a fresh floor-delay timer; consuming it here keeps the
	// final assertion from seeing this stale timer.
	grown := trapReconnect.MustWait(ctx)
	require.Greater(t, grown.Duration, floor,
		"sanity: pre-reset delay should have grown past the floor")
	grown.MustRelease(ctx)

	// Move the chat to waiting. The notification no longer points at
	// a running peer, so closeRelay runs and should reset the retry
	// state.
	waiting := database.UpdateChatStatusParams{
		ID:     chat.ID,
		Status: database.ChatStatusWaiting,
	}
	_, err := db.UpdateChatStatus(ctx, waiting)
	require.NoError(t, err)
	waitingPayload, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
		Status: string(database.ChatStatusWaiting),
	})
	require.NoError(t, err)
	require.NoError(t, ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), waitingPayload))

	// Back to running on another worker: a fresh openRelayAsync fails
	// and arms a reconnect timer whose delay must be the floor.
	setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID2)
	fresh := trapReconnect.MustWait(ctx)
	require.Equal(t, floor, fresh.Duration,
		"retry state must reset after status change; got grown delay %v", fresh.Duration)
	fresh.MustRelease(ctx)
}
// TestRelayBackoffRespectsContextCancel is a regression guard. While a
// reconnect timer is pending, cancelling the Subscribe context must
// close the events channel within testutil.WaitShort of wall-clock
// time.
func TestRelayBackoffRespectsContextCancel(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	workerID := uuid.New()
	subscriberID := uuid.New()

	alwaysFail := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
		[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
	) {
		dialErr := &entchatd.RelayDialError{HTTPStatus: http.StatusBadGateway, Err: io.EOF}
		return nil, nil, nil, dialErr
	}

	mclk := quartz.NewMock(t)
	trapReconnect := mclk.Trap().NewTimer("reconnect")
	defer trapReconnect.Close()
	subscriber := newTestServer(t, db, ps, subscriberID, alwaysFail, mclk)

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	chat := seedWaitingChat(t, db, org.ID, user, model, "relay-cancel")

	subCtx, subCancel := context.WithCancel(ctx)
	_, events, cancel, ok := subscriber.Subscribe(subCtx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID)

	// Let the first reconnect timer arm, but never advance the clock,
	// so the merge goroutine is parked waiting on it.
	pending := trapReconnect.MustWait(ctx)
	pending.MustRelease(ctx)

	// Cancel the subscription; the merge goroutine's select should
	// observe ctx.Done and close events.
	subCancel()

	drained := make(chan struct{})
	go func() {
		defer close(drained)
		for range events {
			// Discard anything still queued until the channel closes.
		}
	}()

	select {
	case <-drained:
	case <-time.After(testutil.WaitShort):
		t.Fatal("events channel did not close after ctx cancel")
	}
}
// TestDialRelayReal401 drives the real dialRelay path against an
// httptest server whose stream endpoint always answers 401. It checks
// that the websocket library's handshake failure surfaces as a
// *RelayDialError carrying HTTPStatus 401, and is therefore reported
// as an authentication failure.
//
// This is the only test that runs the real coder/websocket library on
// the failure path. It is a safety net against library upgrades
// silently breaking status capture.
func TestDialRelayReal401(t *testing.T) {
	t.Parallel()

	// Stream-endpoint requests get a JSON 401; any other path is a 404.
	handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		if !streamPathRE.MatchString(r.URL.Path) {
			http.NotFound(rw, r)
			return
		}
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(http.StatusUnauthorized)
		_, _ = rw.Write([]byte(`{"message":"unauthorized"}`))
	})
	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)

	db, _ := dbtestutil.NewDB(t)
	workerID := uuid.New()
	subscriberID := uuid.New()

	// No DialerFn override, so dialRelay runs end to end against srv.
	// Every worker resolves to the test server's URL.
	cfg := entchatd.MultiReplicaSubscribeConfig{
		ResolveReplicaAddress: func(_ context.Context, _ uuid.UUID) (string, bool) {
			return srv.URL, true
		},
		ReplicaHTTPClient: srv.Client(),
		ReplicaIDFn:       func() uuid.UUID { return subscriberID },
	}
	subscribeFn := entchatd.NewMultiReplicaSubscribeFn(cfg)

	ctx := testutil.Context(t, testutil.WaitMedium)
	user, org, model := seedChatDependencies(t, db)

	// A waiting chat makes Subscribe's synchronous dial a no-op. The
	// running notification pushed below then makes the merge loop
	// call dialRelay on the async path, where the 401 tear-down
	// logic lives.
	chat := seedWaitingChat(t, db, org.ID, user, model, "relay-real-401")
	statusCh := make(chan osschatd.StatusNotification, 1)
	params := osschatd.SubscribeFnParams{
		ChatID:              chat.ID,
		Chat:                chat,
		WorkerID:            subscriberID,
		StatusNotifications: statusCh,
		RequestHeader:       http.Header{codersdk.SessionTokenHeader: {"test-token"}},
		DB:                  db,
		Logger:              slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
	}
	evs := subscribeFn(ctx, params)
	statusCh <- osschatd.StatusNotification{
		Status:   database.ChatStatusRunning,
		WorkerID: workerID,
	}

	// The classifier flags a real 401 handshake as unrecoverable, so
	// expect one dial, then an error event, then channel close. Drain
	// until the channel closes or the deadline expires, remembering
	// the most recent error event.
	var errEvent *codersdk.ChatStreamEvent
	timeout := time.After(testutil.WaitMedium)
	for draining := true; draining; {
		select {
		case ev, open := <-evs:
			switch {
			case !open:
				draining = false
			case ev.Type == codersdk.ChatStreamEventTypeError:
				errEvent = &ev
			}
		case <-timeout:
			draining = false
		}
	}

	require.NotNil(t, errEvent, "expected terminal error event from real 401 dial")
	require.NotNil(t, errEvent.Error)
	require.Contains(t, errEvent.Error.Message, "relay authentication failed")
	require.Contains(t, errEvent.Error.Message, "401")
}
// streamPathRE matches the chat stream endpoint path built by
// buildRelayURL: /api/experimental/chats/<uuid>/stream, with the chat
// UUID in canonical 8-4-4-4-12 hex form (either case) and nothing
// after /stream. It is compiled once at package scope so the httptest
// handler in TestDialRelayReal401 doesn't pay regexp.Compile per
// request.
var streamPathRE = regexp.MustCompile(
	`^/api/experimental/chats/[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}/stream$`,
)