diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index c67e2ee6d0..7a81dc4d6e 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -39,10 +39,11 @@ const ( // prevents infinite compaction loops when the model keeps // hitting the context limit after summarization. maxCompactionRetries = 3 - // defaultStartupTimeout bounds how long an individual - // model attempt may spend starting to respond before + // defaultStreamSilenceTimeout bounds how long an individual + // model attempt may go without receiving a stream part before // the attempt is canceled and retried. - defaultStartupTimeout = 60 * time.Second + defaultStreamSilenceTimeout = 10 * time.Minute + streamSilenceGuardTimerTag = "streamSilenceGuard" ) var ( @@ -53,8 +54,8 @@ var ( // the run should terminate cleanly after persistence. ErrStopAfterTool = xerrors.New("stop after tool") - errStartupTimeout = xerrors.New( - "chat response did not start before the startup timeout", + errStreamSilenceTimeout = xerrors.New( + "chat stream was silent for longer than the configured timeout", ) ) @@ -114,14 +115,14 @@ type RunOptions struct { Messages []fantasy.Message Tools []fantasy.AgentTool MaxSteps int - // StartupTimeout bounds how long each model attempt may - // spend opening the provider stream and waiting for its - // first stream part before the attempt is canceled and - // retried. Zero uses the production default. - StartupTimeout time.Duration - // Clock creates startup guard timers. In production use a - // real clock; tests can inject quartz.NewMock(t) to make - // startup timeout behavior deterministic. + // StreamSilenceTimeout bounds how long each model attempt + // may go without receiving a stream part before the + // attempt is canceled and retried. Zero uses the + // production default. + StreamSilenceTimeout time.Duration + // Clock creates stream silence guard timers. In production + // use a real clock; tests can inject quartz.NewMock(t) to + // make timeout behavior deterministic. Clock quartz.Clock ActiveTools []string @@ -364,8 +365,8 @@ func Run(ctx context.Context, opts RunOptions) error { if opts.MaxSteps <= 0 { opts.MaxSteps = 1 } - if opts.StartupTimeout <= 0 { - opts.StartupTimeout = defaultStartupTimeout + if opts.StreamSilenceTimeout <= 0 { + opts.StreamSilenceTimeout = defaultStreamSilenceTimeout } if opts.Clock == nil { opts.Clock = quartz.NewReal() @@ -468,7 +469,7 @@ func Run(ctx context.Context, opts RunOptions) error { provider, modelName, opts.Clock, - opts.StartupTimeout, + opts.StreamSilenceTimeout, func(attemptCtx context.Context) (fantasy.StreamResponse, error) { return opts.Model.Stream(attemptCtx, call) }, @@ -782,9 +783,9 @@ func prepareMessagesForRequest( return canonical, prompt, nil } -// guardedAttempt owns an attempt-scoped context and startup guard +// guardedAttempt owns an attempt-scoped context and silence guard // around a provider stream. release is idempotent and frees the -// attempt-scoped timer/context. finish canonicalizes startup timeout +// attempt-scoped timer/context. finish canonicalizes silence timeout // errors before the retry loop classifies them. type guardedAttempt struct { ctx context.Context @@ -793,47 +794,77 @@ type guardedAttempt struct { finish func(error) error } -// startupGuard arbitrates whether an attempt times out during -// stream startup. Exactly one outcome wins: the timer cancels -// the attempt, or the first-part path disarms the timer. -type startupGuard struct { - timer *quartz.Timer - cancel context.CancelCauseFunc - once sync.Once +// streamSilenceGuard arbitrates whether an attempt times out while +// waiting for the next stream part. Exactly one outcome wins: the +// timer cancels the attempt, or release disarms the timer. +type streamSilenceGuard struct { + mu sync.Mutex + timer *quartz.Timer + cancel context.CancelCauseFunc + timeout time.Duration + settled bool } -func newStartupGuard( +func newStreamSilenceGuard( clock quartz.Clock, timeout time.Duration, cancel context.CancelCauseFunc, -) *startupGuard { - guard := &startupGuard{cancel: cancel} - guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard") +) *streamSilenceGuard { + guard := &streamSilenceGuard{ + cancel: cancel, + timeout: timeout, + } + guard.timer = clock.AfterFunc( + timeout, + guard.onTimeout, + streamSilenceGuardTimerTag, + ) return guard } -func (g *startupGuard) onTimeout() { - g.once.Do(func() { - g.cancel(errStartupTimeout) - }) +func (g *streamSilenceGuard) settle() bool { + g.mu.Lock() + defer g.mu.Unlock() + if g.settled { + return false + } + g.settled = true + return true } -func (g *startupGuard) Disarm() { - g.once.Do(func() { - g.timer.Stop() - }) +func (g *streamSilenceGuard) onTimeout() { + if !g.settle() { + return + } + g.cancel(errStreamSilenceTimeout) } -func classifyStartupTimeout( +func (g *streamSilenceGuard) Reset() { + g.mu.Lock() + defer g.mu.Unlock() + if g.settled { + return + } + g.timer.Reset(g.timeout, streamSilenceGuardTimerTag) +} + +func (g *streamSilenceGuard) Disarm() { + if !g.settle() { + return + } + g.timer.Stop() +} + +func classifyStreamSilenceTimeout( attemptCtx context.Context, provider string, err error, ) error { - if !errors.Is(context.Cause(attemptCtx), errStartupTimeout) { + if !errors.Is(context.Cause(attemptCtx), errStreamSilenceTimeout) { return err } if err == nil { - err = errStartupTimeout + err = errStreamSilenceTimeout } return chaterror.WithClassification(err, chaterror.ClassifiedError{ Kind: codersdk.ChatErrorKindStartupTimeout, @@ -851,7 +882,7 @@ func guardedStream( metrics *Metrics, ) (guardedAttempt, error) { attemptCtx, cancelAttempt := context.WithCancelCause(parent) - guard := newStartupGuard(clock, timeout, cancelAttempt) + guard := newStreamSilenceGuard(clock, timeout, cancelAttempt) var releaseOnce sync.Once release := func() { releaseOnce.Do(func() { @@ -863,7 +894,7 @@ func guardedStream( streamStart := clock.Now() stream, err := openStream(attemptCtx) if err != nil { - err = classifyStartupTimeout(attemptCtx, provider, err) + err = classifyStreamSilenceTimeout(attemptCtx, provider, err) release() return guardedAttempt{}, err } @@ -877,7 +908,7 @@ func guardedStream( ctx: attemptCtx, stream: fantasy.StreamResponse(func(yield func(fantasy.StreamPart) bool) { for part := range stream { - guard.Disarm() + guard.Reset() recordTTFT() if !yield(part) { return @@ -886,7 +917,7 @@ func guardedStream( }), release: release, finish: func(err error) error { - return classifyStartupTimeout(attemptCtx, provider, err) + return classifyStreamSilenceTimeout(attemptCtx, provider, err) }, }, nil } diff --git a/coderd/x/chatd/chatloop/chatloop_run_internal_test.go b/coderd/x/chatd/chatloop/chatloop_run_internal_test.go index a7a2079d79..64b1d8f97c 100644 --- a/coderd/x/chatd/chatloop/chatloop_run_internal_test.go +++ b/coderd/x/chatd/chatloop/chatloop_run_internal_test.go @@ -581,13 +581,13 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) { ) } -func TestStartupGuard_DisarmAndFireRace(t *testing.T) { +func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) { t.Parallel() for range 128 { var cancels atomic.Int32 - guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) { - if errors.Is(err, errStartupTimeout) { + guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, func(err error) { + if errors.Is(err, errStreamSilenceTimeout) { cancels.Add(1) } }) @@ -618,17 +618,17 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) { } } -func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) { +func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) { t.Parallel() attemptCtx, cancelAttempt := context.WithCancelCause(context.Background()) defer cancelAttempt(nil) - guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt) + guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, cancelAttempt) guard.Disarm() guard.onTimeout() - classified := chaterror.Classify(classifyStartupTimeout( + classified := chaterror.Classify(classifyStreamSilenceTimeout( attemptCtx, "openai", xerrors.New("invalid model"), @@ -638,10 +638,10 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) { require.Nil(t, context.Cause(attemptCtx)) } -func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { +func TestRun_RetriesSilenceTimeoutWhileOpeningStream(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -650,7 +650,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -675,10 +675,10 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -694,7 +694,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { }() trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(startupTimeout).MustWait(ctx) + mClock.Advance(silenceTimeout).MustWait(ctx) trap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) @@ -710,9 +710,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) { ) select { case cause := <-attemptCause: - require.ErrorIs(t, cause, errStartupTimeout) + require.ErrorIs(t, cause, errStreamSilenceTimeout) case <-ctx.Done(): - t.Fatal("timed out waiting for startup timeout cause") + t.Fatal("timed out waiting for silence timeout cause") } } @@ -728,7 +728,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { t.Run(provider, func(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -737,7 +737,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -763,10 +763,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -795,10 +795,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { } } -func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { +func TestRun_RetriesSilenceTimeoutBeforeFirstPart(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -807,7 +807,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -837,10 +837,10 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -856,7 +856,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { }() trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(startupTimeout).MustWait(ctx) + mClock.Advance(silenceTimeout).MustWait(ctx) trap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) @@ -872,16 +872,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) { ) select { case cause := <-attemptCause: - require.ErrorIs(t, cause, errStartupTimeout) + require.ErrorIs(t, cause, errStreamSilenceTimeout) case <-ctx.Done(): - t.Fatal("timed out waiting for startup timeout cause") + t.Fatal("timed out waiting for silence timeout cause") } } -func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { +func TestRun_StreamPartsResetSilenceTimeout(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -890,12 +890,17 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer armTrap.Close() + resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag) + defer resetTrap.Close() attempts := 0 retried := false firstPartYielded := make(chan struct{}, 1) - continueStream := make(chan struct{}) + secondPartYielded := make(chan struct{}, 1) + continueToSecond := make(chan struct{}) + continueToFinish := make(chan struct{}) model := &chattest.FakeModel{ ProviderName: "openai", StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { @@ -910,7 +915,29 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { } select { - case <-continueStream: + case <-continueToSecond: + case <-ctx.Done(): + _ = yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: ctx.Err(), + }) + return + } + + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeTextDelta, + ID: "text-1", + Delta: "done", + }) { + return + } + select { + case secondPartYielded <- struct{}{}: + default: + } + + select { + case <-continueToFinish: case <-ctx.Done(): _ = yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, @@ -920,7 +947,6 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { } parts := []fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, } @@ -936,10 +962,10 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -954,23 +980,130 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) { }) }() - trap.MustWait(ctx).MustRelease(ctx) - trap.Close() - + armTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) select { case <-firstPartYielded: case <-ctx.Done(): t.Fatal("timed out waiting for first stream part") } - mClock.Advance(startupTimeout).MustWait(ctx) - close(continueStream) + mClock.Advance(silenceTimeout / 2).MustWait(ctx) + close(continueToSecond) + resetTrap.MustWait(ctx).MustRelease(ctx) + select { + case <-secondPartYielded: + case <-ctx.Done(): + t.Fatal("timed out waiting for second stream part") + } + + mClock.Advance(silenceTimeout / 2).MustWait(ctx) + close(continueToFinish) + resetTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) require.Equal(t, 1, attempts) require.False(t, retried) } +func TestRun_RetriesSilenceTimeoutBetweenParts(t *testing.T) { + t.Parallel() + + const silenceTimeout = 5 * time.Millisecond + + ctx, cancel := context.WithTimeout( + context.Background(), + testutil.WaitLong, + ) + defer cancel() + + mClock := quartz.NewMock(t) + armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer armTrap.Close() + resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag) + defer resetTrap.Close() + + attempts := 0 + firstPartYielded := make(chan struct{}, 1) + attemptCause := make(chan error, 1) + var retries []chatretry.ClassifiedError + model := &chattest.FakeModel{ + ProviderName: "openai", + StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + attempts++ + if attempts == 1 { + return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) { + return + } + select { + case firstPartYielded <- struct{}{}: + default: + } + + <-ctx.Done() + attemptCause <- context.Cause(ctx) + _ = yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: ctx.Err(), + }) + }), nil + } + return streamFromParts([]fantasy.StreamPart{{ + Type: fantasy.StreamPartTypeFinish, + FinishReason: fantasy.FinishReasonStop, + }}), nil + }, + } + + done := make(chan error, 1) + go func() { + done <- Run(context.Background(), RunOptions{ + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, + PersistStep: func(_ context.Context, _ PersistedStep) error { + return nil + }, + OnRetry: func( + _ int, + _ error, + classified chatretry.ClassifiedError, + _ time.Duration, + ) { + retries = append(retries, classified) + }, + }) + }() + + armTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) + select { + case <-firstPartYielded: + case <-ctx.Done(): + t.Fatal("timed out waiting for first stream part") + } + + mClock.Advance(silenceTimeout).MustWait(ctx) + armTrap.MustWait(ctx).MustRelease(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) + + require.NoError(t, awaitRunResult(ctx, t, done)) + require.Equal(t, 2, attempts) + require.Len(t, retries, 1) + require.Equal(t, codersdk.ChatErrorKindStartupTimeout, retries[0].Kind) + require.True(t, retries[0].Retryable) + require.Equal(t, "openai", retries[0].Provider) + select { + case cause := <-attemptCause: + require.ErrorIs(t, cause, errStreamSilenceTimeout) + case <-ctx.Done(): + t.Fatal("timed out waiting for silence timeout cause") + } +} + func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { t.Parallel() @@ -1014,10 +1147,10 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { t.Fatal("expected Run to panic") } -func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { +func TestRun_RetriesSilenceTimeoutWhenStreamStaysSilent(t *testing.T) { t.Parallel() - const startupTimeout = 5 * time.Millisecond + const silenceTimeout = 5 * time.Millisecond ctx, cancel := context.WithTimeout( context.Background(), @@ -1026,7 +1159,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { defer cancel() mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc("startupGuard") + trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) defer trap.Close() attempts := 0 @@ -1052,10 +1185,10 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { done := make(chan error, 1) go func() { done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StartupTimeout: startupTimeout, - Clock: mClock, + Model: model, + MaxSteps: 1, + StreamSilenceTimeout: silenceTimeout, + Clock: mClock, PersistStep: func(_ context.Context, _ PersistedStep) error { return nil }, @@ -1071,7 +1204,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { }() trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(startupTimeout).MustWait(ctx) + mClock.Advance(silenceTimeout).MustWait(ctx) trap.MustWait(ctx).MustRelease(ctx) require.NoError(t, awaitRunResult(ctx, t, done)) @@ -1087,9 +1220,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) { ) select { case cause := <-attemptCause: - require.ErrorIs(t, cause, errStartupTimeout) + require.ErrorIs(t, cause, errStreamSilenceTimeout) case <-ctx.Done(): - t.Fatal("timed out waiting for startup timeout cause") + t.Fatal("timed out waiting for silence timeout cause") } }