fix(coderd/x/chatd/chatloop): use stream silence timeout (#25782) (#25786)

Cherry-pick of https://github.com/coder/coder/pull/25782

Original PR: #25782 — fix(coderd/x/chatd/chatloop): use stream silence
timeout
Merge commit: 7e2f7198dd
Requested by: @ethanndickson

Co-authored-by: Ethan <ethanndickson@gmail.com>
This commit is contained in:
github-actions[bot]
2026-05-28 11:29:14 -04:00
committed by GitHub
parent eb8b062b1d
commit 85d39b3dbe
2 changed files with 265 additions and 101 deletions
+75 -44
View File
@@ -39,10 +39,11 @@ const (
// prevents infinite compaction loops when the model keeps
// hitting the context limit after summarization.
maxCompactionRetries = 3
// defaultStartupTimeout bounds how long an individual
// model attempt may spend starting to respond before
// defaultStreamSilenceTimeout bounds how long an individual
// model attempt may go without receiving a stream part before
// the attempt is canceled and retried.
defaultStartupTimeout = 60 * time.Second
defaultStreamSilenceTimeout = 10 * time.Minute
streamSilenceGuardTimerTag = "streamSilenceGuard"
)
var (
@@ -53,8 +54,8 @@ var (
// the run should terminate cleanly after persistence.
ErrStopAfterTool = xerrors.New("stop after tool")
errStartupTimeout = xerrors.New(
"chat response did not start before the startup timeout",
errStreamSilenceTimeout = xerrors.New(
"chat stream was silent for longer than the configured timeout",
)
)
@@ -114,14 +115,14 @@ type RunOptions struct {
Messages []fantasy.Message
Tools []fantasy.AgentTool
MaxSteps int
// StartupTimeout bounds how long each model attempt may
// spend opening the provider stream and waiting for its
// first stream part before the attempt is canceled and
// retried. Zero uses the production default.
StartupTimeout time.Duration
// Clock creates startup guard timers. In production use a
// real clock; tests can inject quartz.NewMock(t) to make
// startup timeout behavior deterministic.
// StreamSilenceTimeout bounds how long each model attempt
// may go without receiving a stream part before the
// attempt is canceled and retried. Zero uses the
// production default.
StreamSilenceTimeout time.Duration
// Clock creates stream silence guard timers. In production
// use a real clock; tests can inject quartz.NewMock(t) to
// make timeout behavior deterministic.
Clock quartz.Clock
ActiveTools []string
@@ -364,8 +365,8 @@ func Run(ctx context.Context, opts RunOptions) error {
if opts.MaxSteps <= 0 {
opts.MaxSteps = 1
}
if opts.StartupTimeout <= 0 {
opts.StartupTimeout = defaultStartupTimeout
if opts.StreamSilenceTimeout <= 0 {
opts.StreamSilenceTimeout = defaultStreamSilenceTimeout
}
if opts.Clock == nil {
opts.Clock = quartz.NewReal()
@@ -468,7 +469,7 @@ func Run(ctx context.Context, opts RunOptions) error {
provider,
modelName,
opts.Clock,
opts.StartupTimeout,
opts.StreamSilenceTimeout,
func(attemptCtx context.Context) (fantasy.StreamResponse, error) {
return opts.Model.Stream(attemptCtx, call)
},
@@ -782,9 +783,9 @@ func prepareMessagesForRequest(
return canonical, prompt, nil
}
// guardedAttempt owns an attempt-scoped context and startup guard
// guardedAttempt owns an attempt-scoped context and silence guard
// around a provider stream. release is idempotent and frees the
// attempt-scoped timer/context. finish canonicalizes startup timeout
// attempt-scoped timer/context. finish canonicalizes silence timeout
// errors before the retry loop classifies them.
type guardedAttempt struct {
ctx context.Context
@@ -793,47 +794,77 @@ type guardedAttempt struct {
finish func(error) error
}
// startupGuard arbitrates whether an attempt times out during
// stream startup. Exactly one outcome wins: the timer cancels
// the attempt, or the first-part path disarms the timer.
type startupGuard struct {
timer *quartz.Timer
cancel context.CancelCauseFunc
once sync.Once
// streamSilenceGuard arbitrates whether an attempt times out while
// waiting for the next stream part. Exactly one outcome wins: the
// timer cancels the attempt, or release disarms the timer.
type streamSilenceGuard struct {
mu sync.Mutex
timer *quartz.Timer
cancel context.CancelCauseFunc
timeout time.Duration
settled bool
}
func newStartupGuard(
func newStreamSilenceGuard(
clock quartz.Clock,
timeout time.Duration,
cancel context.CancelCauseFunc,
) *startupGuard {
guard := &startupGuard{cancel: cancel}
guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard")
) *streamSilenceGuard {
guard := &streamSilenceGuard{
cancel: cancel,
timeout: timeout,
}
guard.timer = clock.AfterFunc(
timeout,
guard.onTimeout,
streamSilenceGuardTimerTag,
)
return guard
}
func (g *startupGuard) onTimeout() {
g.once.Do(func() {
g.cancel(errStartupTimeout)
})
func (g *streamSilenceGuard) settle() bool {
g.mu.Lock()
defer g.mu.Unlock()
if g.settled {
return false
}
g.settled = true
return true
}
func (g *startupGuard) Disarm() {
g.once.Do(func() {
g.timer.Stop()
})
func (g *streamSilenceGuard) onTimeout() {
if !g.settle() {
return
}
g.cancel(errStreamSilenceTimeout)
}
func classifyStartupTimeout(
func (g *streamSilenceGuard) Reset() {
g.mu.Lock()
defer g.mu.Unlock()
if g.settled {
return
}
g.timer.Reset(g.timeout, streamSilenceGuardTimerTag)
}
func (g *streamSilenceGuard) Disarm() {
if !g.settle() {
return
}
g.timer.Stop()
}
func classifyStreamSilenceTimeout(
attemptCtx context.Context,
provider string,
err error,
) error {
if !errors.Is(context.Cause(attemptCtx), errStartupTimeout) {
if !errors.Is(context.Cause(attemptCtx), errStreamSilenceTimeout) {
return err
}
if err == nil {
err = errStartupTimeout
err = errStreamSilenceTimeout
}
return chaterror.WithClassification(err, chaterror.ClassifiedError{
Kind: codersdk.ChatErrorKindStartupTimeout,
@@ -851,7 +882,7 @@ func guardedStream(
metrics *Metrics,
) (guardedAttempt, error) {
attemptCtx, cancelAttempt := context.WithCancelCause(parent)
guard := newStartupGuard(clock, timeout, cancelAttempt)
guard := newStreamSilenceGuard(clock, timeout, cancelAttempt)
var releaseOnce sync.Once
release := func() {
releaseOnce.Do(func() {
@@ -863,7 +894,7 @@ func guardedStream(
streamStart := clock.Now()
stream, err := openStream(attemptCtx)
if err != nil {
err = classifyStartupTimeout(attemptCtx, provider, err)
err = classifyStreamSilenceTimeout(attemptCtx, provider, err)
release()
return guardedAttempt{}, err
}
@@ -877,7 +908,7 @@ func guardedStream(
ctx: attemptCtx,
stream: fantasy.StreamResponse(func(yield func(fantasy.StreamPart) bool) {
for part := range stream {
guard.Disarm()
guard.Reset()
recordTTFT()
if !yield(part) {
return
@@ -886,7 +917,7 @@ func guardedStream(
}),
release: release,
finish: func(err error) error {
return classifyStartupTimeout(attemptCtx, provider, err)
return classifyStreamSilenceTimeout(attemptCtx, provider, err)
},
}, nil
}
@@ -581,13 +581,13 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) {
)
}
func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) {
t.Parallel()
for range 128 {
var cancels atomic.Int32
guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) {
if errors.Is(err, errStartupTimeout) {
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, func(err error) {
if errors.Is(err, errStreamSilenceTimeout) {
cancels.Add(1)
}
})
@@ -618,17 +618,17 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
}
}
func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) {
t.Parallel()
attemptCtx, cancelAttempt := context.WithCancelCause(context.Background())
defer cancelAttempt(nil)
guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt)
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, cancelAttempt)
guard.Disarm()
guard.onTimeout()
classified := chaterror.Classify(classifyStartupTimeout(
classified := chaterror.Classify(classifyStreamSilenceTimeout(
attemptCtx,
"openai",
xerrors.New("invalid model"),
@@ -638,10 +638,10 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
require.Nil(t, context.Cause(attemptCtx))
}
func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
func TestRun_RetriesSilenceTimeoutWhileOpeningStream(t *testing.T) {
t.Parallel()
const startupTimeout = 5 * time.Millisecond
const silenceTimeout = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(
context.Background(),
@@ -650,7 +650,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
defer cancel()
mClock := quartz.NewMock(t)
trap := mClock.Trap().AfterFunc("startupGuard")
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
defer trap.Close()
attempts := 0
@@ -675,10 +675,10 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
done := make(chan error, 1)
go func() {
done <- Run(context.Background(), RunOptions{
Model: model,
MaxSteps: 1,
StartupTimeout: startupTimeout,
Clock: mClock,
Model: model,
MaxSteps: 1,
StreamSilenceTimeout: silenceTimeout,
Clock: mClock,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
@@ -694,7 +694,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
}()
trap.MustWait(ctx).MustRelease(ctx)
mClock.Advance(startupTimeout).MustWait(ctx)
mClock.Advance(silenceTimeout).MustWait(ctx)
trap.MustWait(ctx).MustRelease(ctx)
require.NoError(t, awaitRunResult(ctx, t, done))
@@ -710,9 +710,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
)
select {
case cause := <-attemptCause:
require.ErrorIs(t, cause, errStartupTimeout)
require.ErrorIs(t, cause, errStreamSilenceTimeout)
case <-ctx.Done():
t.Fatal("timed out waiting for startup timeout cause")
t.Fatal("timed out waiting for silence timeout cause")
}
}
@@ -728,7 +728,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
t.Run(provider, func(t *testing.T) {
t.Parallel()
const startupTimeout = 5 * time.Millisecond
const silenceTimeout = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(
context.Background(),
@@ -737,7 +737,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
defer cancel()
mClock := quartz.NewMock(t)
trap := mClock.Trap().AfterFunc("startupGuard")
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
defer trap.Close()
attempts := 0
@@ -763,10 +763,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
done := make(chan error, 1)
go func() {
done <- Run(context.Background(), RunOptions{
Model: model,
MaxSteps: 1,
StartupTimeout: startupTimeout,
Clock: mClock,
Model: model,
MaxSteps: 1,
StreamSilenceTimeout: silenceTimeout,
Clock: mClock,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
@@ -795,10 +795,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
}
}
func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
func TestRun_RetriesSilenceTimeoutBeforeFirstPart(t *testing.T) {
t.Parallel()
const startupTimeout = 5 * time.Millisecond
const silenceTimeout = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(
context.Background(),
@@ -807,7 +807,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
defer cancel()
mClock := quartz.NewMock(t)
trap := mClock.Trap().AfterFunc("startupGuard")
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
defer trap.Close()
attempts := 0
@@ -837,10 +837,10 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
done := make(chan error, 1)
go func() {
done <- Run(context.Background(), RunOptions{
Model: model,
MaxSteps: 1,
StartupTimeout: startupTimeout,
Clock: mClock,
Model: model,
MaxSteps: 1,
StreamSilenceTimeout: silenceTimeout,
Clock: mClock,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
@@ -856,7 +856,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
}()
trap.MustWait(ctx).MustRelease(ctx)
mClock.Advance(startupTimeout).MustWait(ctx)
mClock.Advance(silenceTimeout).MustWait(ctx)
trap.MustWait(ctx).MustRelease(ctx)
require.NoError(t, awaitRunResult(ctx, t, done))
@@ -872,16 +872,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
)
select {
case cause := <-attemptCause:
require.ErrorIs(t, cause, errStartupTimeout)
require.ErrorIs(t, cause, errStreamSilenceTimeout)
case <-ctx.Done():
t.Fatal("timed out waiting for startup timeout cause")
t.Fatal("timed out waiting for silence timeout cause")
}
}
func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
func TestRun_StreamPartsResetSilenceTimeout(t *testing.T) {
t.Parallel()
const startupTimeout = 5 * time.Millisecond
const silenceTimeout = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(
context.Background(),
@@ -890,12 +890,17 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
defer cancel()
mClock := quartz.NewMock(t)
trap := mClock.Trap().AfterFunc("startupGuard")
armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
defer armTrap.Close()
resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag)
defer resetTrap.Close()
attempts := 0
retried := false
firstPartYielded := make(chan struct{}, 1)
continueStream := make(chan struct{})
secondPartYielded := make(chan struct{}, 1)
continueToSecond := make(chan struct{})
continueToFinish := make(chan struct{})
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
@@ -910,7 +915,29 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
}
select {
case <-continueStream:
case <-continueToSecond:
case <-ctx.Done():
_ = yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: ctx.Err(),
})
return
}
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeTextDelta,
ID: "text-1",
Delta: "done",
}) {
return
}
select {
case secondPartYielded <- struct{}{}:
default:
}
select {
case <-continueToFinish:
case <-ctx.Done():
_ = yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
@@ -920,7 +947,6 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
}
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}
@@ -936,10 +962,10 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
done := make(chan error, 1)
go func() {
done <- Run(context.Background(), RunOptions{
Model: model,
MaxSteps: 1,
StartupTimeout: startupTimeout,
Clock: mClock,
Model: model,
MaxSteps: 1,
StreamSilenceTimeout: silenceTimeout,
Clock: mClock,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
@@ -954,23 +980,130 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
})
}()
trap.MustWait(ctx).MustRelease(ctx)
trap.Close()
armTrap.MustWait(ctx).MustRelease(ctx)
resetTrap.MustWait(ctx).MustRelease(ctx)
select {
case <-firstPartYielded:
case <-ctx.Done():
t.Fatal("timed out waiting for first stream part")
}
mClock.Advance(startupTimeout).MustWait(ctx)
close(continueStream)
mClock.Advance(silenceTimeout / 2).MustWait(ctx)
close(continueToSecond)
resetTrap.MustWait(ctx).MustRelease(ctx)
select {
case <-secondPartYielded:
case <-ctx.Done():
t.Fatal("timed out waiting for second stream part")
}
mClock.Advance(silenceTimeout / 2).MustWait(ctx)
close(continueToFinish)
resetTrap.MustWait(ctx).MustRelease(ctx)
resetTrap.MustWait(ctx).MustRelease(ctx)
require.NoError(t, awaitRunResult(ctx, t, done))
require.Equal(t, 1, attempts)
require.False(t, retried)
}
func TestRun_RetriesSilenceTimeoutBetweenParts(t *testing.T) {
t.Parallel()
const silenceTimeout = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(
context.Background(),
testutil.WaitLong,
)
defer cancel()
mClock := quartz.NewMock(t)
armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
defer armTrap.Close()
resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag)
defer resetTrap.Close()
attempts := 0
firstPartYielded := make(chan struct{}, 1)
attemptCause := make(chan error, 1)
var retries []chatretry.ClassifiedError
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
attempts++
if attempts == 1 {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
return
}
select {
case firstPartYielded <- struct{}{}:
default:
}
<-ctx.Done()
attemptCause <- context.Cause(ctx)
_ = yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: ctx.Err(),
})
}), nil
}
return streamFromParts([]fantasy.StreamPart{{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
}}), nil
},
}
done := make(chan error, 1)
go func() {
done <- Run(context.Background(), RunOptions{
Model: model,
MaxSteps: 1,
StreamSilenceTimeout: silenceTimeout,
Clock: mClock,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
OnRetry: func(
_ int,
_ error,
classified chatretry.ClassifiedError,
_ time.Duration,
) {
retries = append(retries, classified)
},
})
}()
armTrap.MustWait(ctx).MustRelease(ctx)
resetTrap.MustWait(ctx).MustRelease(ctx)
select {
case <-firstPartYielded:
case <-ctx.Done():
t.Fatal("timed out waiting for first stream part")
}
mClock.Advance(silenceTimeout).MustWait(ctx)
armTrap.MustWait(ctx).MustRelease(ctx)
resetTrap.MustWait(ctx).MustRelease(ctx)
require.NoError(t, awaitRunResult(ctx, t, done))
require.Equal(t, 2, attempts)
require.Len(t, retries, 1)
require.Equal(t, codersdk.ChatErrorKindStartupTimeout, retries[0].Kind)
require.True(t, retries[0].Retryable)
require.Equal(t, "openai", retries[0].Provider)
select {
case cause := <-attemptCause:
require.ErrorIs(t, cause, errStreamSilenceTimeout)
case <-ctx.Done():
t.Fatal("timed out waiting for silence timeout cause")
}
}
func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
t.Parallel()
@@ -1014,10 +1147,10 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
t.Fatal("expected Run to panic")
}
func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
func TestRun_RetriesSilenceTimeoutWhenStreamStaysSilent(t *testing.T) {
t.Parallel()
const startupTimeout = 5 * time.Millisecond
const silenceTimeout = 5 * time.Millisecond
ctx, cancel := context.WithTimeout(
context.Background(),
@@ -1026,7 +1159,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
defer cancel()
mClock := quartz.NewMock(t)
trap := mClock.Trap().AfterFunc("startupGuard")
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
defer trap.Close()
attempts := 0
@@ -1052,10 +1185,10 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
done := make(chan error, 1)
go func() {
done <- Run(context.Background(), RunOptions{
Model: model,
MaxSteps: 1,
StartupTimeout: startupTimeout,
Clock: mClock,
Model: model,
MaxSteps: 1,
StreamSilenceTimeout: silenceTimeout,
Clock: mClock,
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
@@ -1071,7 +1204,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
}()
trap.MustWait(ctx).MustRelease(ctx)
mClock.Advance(startupTimeout).MustWait(ctx)
mClock.Advance(silenceTimeout).MustWait(ctx)
trap.MustWait(ctx).MustRelease(ctx)
require.NoError(t, awaitRunResult(ctx, t, done))
@@ -1087,9 +1220,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
)
select {
case cause := <-attemptCause:
require.ErrorIs(t, cause, errStartupTimeout)
require.ErrorIs(t, cause, errStreamSilenceTimeout)
case <-ctx.Done():
t.Fatal("timed out waiting for startup timeout cause")
t.Fatal("timed out waiting for silence timeout cause")
}
}