mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd/x/chatd/chatloop): use stream silence timeout (#25782)
Replaces the 60 second first-token timeout in the chat loop with a 10 minute stream-silence timeout. Previously, the guard bounded only the gap before the first stream part. Once any part arrived the attempt could hang indefinitely if the provider stopped streaming without closing the connection, and even normal long-running responses could be killed after 60 seconds if the provider was slow to emit the first token. The guard now arms when a model attempt opens its stream, resets on every received stream part, and fires after 10 minutes of complete silence. The existing retry path still handles the timeout, and the public `startup_timeout` error kind is preserved to avoid API and frontend churn. 10 minutes matches the default request timeout used by the Anthropic and OpenAI Python SDKs. Closes CODAGT-493
This commit is contained in:
@@ -39,10 +39,11 @@ const (
|
||||
// prevents infinite compaction loops when the model keeps
|
||||
// hitting the context limit after summarization.
|
||||
maxCompactionRetries = 3
|
||||
// defaultStartupTimeout bounds how long an individual
|
||||
// model attempt may spend starting to respond before
|
||||
// defaultStreamSilenceTimeout bounds how long an individual
|
||||
// model attempt may go without receiving a stream part before
|
||||
// the attempt is canceled and retried.
|
||||
defaultStartupTimeout = 60 * time.Second
|
||||
defaultStreamSilenceTimeout = 10 * time.Minute
|
||||
streamSilenceGuardTimerTag = "streamSilenceGuard"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -53,8 +54,8 @@ var (
|
||||
// the run should terminate cleanly after persistence.
|
||||
ErrStopAfterTool = xerrors.New("stop after tool")
|
||||
|
||||
errStartupTimeout = xerrors.New(
|
||||
"chat response did not start before the startup timeout",
|
||||
errStreamSilenceTimeout = xerrors.New(
|
||||
"chat stream was silent for longer than the configured timeout",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -114,14 +115,14 @@ type RunOptions struct {
|
||||
Messages []fantasy.Message
|
||||
Tools []fantasy.AgentTool
|
||||
MaxSteps int
|
||||
// StartupTimeout bounds how long each model attempt may
|
||||
// spend opening the provider stream and waiting for its
|
||||
// first stream part before the attempt is canceled and
|
||||
// retried. Zero uses the production default.
|
||||
StartupTimeout time.Duration
|
||||
// Clock creates startup guard timers. In production use a
|
||||
// real clock; tests can inject quartz.NewMock(t) to make
|
||||
// startup timeout behavior deterministic.
|
||||
// StreamSilenceTimeout bounds how long each model attempt
|
||||
// may go without receiving a stream part before the
|
||||
// attempt is canceled and retried. Zero uses the
|
||||
// production default.
|
||||
StreamSilenceTimeout time.Duration
|
||||
// Clock creates stream silence guard timers. In production
|
||||
// use a real clock; tests can inject quartz.NewMock(t) to
|
||||
// make timeout behavior deterministic.
|
||||
Clock quartz.Clock
|
||||
|
||||
ActiveTools []string
|
||||
@@ -364,8 +365,8 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
if opts.MaxSteps <= 0 {
|
||||
opts.MaxSteps = 1
|
||||
}
|
||||
if opts.StartupTimeout <= 0 {
|
||||
opts.StartupTimeout = defaultStartupTimeout
|
||||
if opts.StreamSilenceTimeout <= 0 {
|
||||
opts.StreamSilenceTimeout = defaultStreamSilenceTimeout
|
||||
}
|
||||
if opts.Clock == nil {
|
||||
opts.Clock = quartz.NewReal()
|
||||
@@ -468,7 +469,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
provider,
|
||||
modelName,
|
||||
opts.Clock,
|
||||
opts.StartupTimeout,
|
||||
opts.StreamSilenceTimeout,
|
||||
func(attemptCtx context.Context) (fantasy.StreamResponse, error) {
|
||||
return opts.Model.Stream(attemptCtx, call)
|
||||
},
|
||||
@@ -782,9 +783,9 @@ func prepareMessagesForRequest(
|
||||
return canonical, prompt, nil
|
||||
}
|
||||
|
||||
// guardedAttempt owns an attempt-scoped context and startup guard
|
||||
// guardedAttempt owns an attempt-scoped context and silence guard
|
||||
// around a provider stream. release is idempotent and frees the
|
||||
// attempt-scoped timer/context. finish canonicalizes startup timeout
|
||||
// attempt-scoped timer/context. finish canonicalizes silence timeout
|
||||
// errors before the retry loop classifies them.
|
||||
type guardedAttempt struct {
|
||||
ctx context.Context
|
||||
@@ -793,47 +794,77 @@ type guardedAttempt struct {
|
||||
finish func(error) error
|
||||
}
|
||||
|
||||
// startupGuard arbitrates whether an attempt times out during
|
||||
// stream startup. Exactly one outcome wins: the timer cancels
|
||||
// the attempt, or the first-part path disarms the timer.
|
||||
type startupGuard struct {
|
||||
timer *quartz.Timer
|
||||
cancel context.CancelCauseFunc
|
||||
once sync.Once
|
||||
// streamSilenceGuard arbitrates whether an attempt times out while
|
||||
// waiting for the next stream part. Exactly one outcome wins: the
|
||||
// timer cancels the attempt, or release disarms the timer.
|
||||
type streamSilenceGuard struct {
|
||||
mu sync.Mutex
|
||||
timer *quartz.Timer
|
||||
cancel context.CancelCauseFunc
|
||||
timeout time.Duration
|
||||
settled bool
|
||||
}
|
||||
|
||||
func newStartupGuard(
|
||||
func newStreamSilenceGuard(
|
||||
clock quartz.Clock,
|
||||
timeout time.Duration,
|
||||
cancel context.CancelCauseFunc,
|
||||
) *startupGuard {
|
||||
guard := &startupGuard{cancel: cancel}
|
||||
guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard")
|
||||
) *streamSilenceGuard {
|
||||
guard := &streamSilenceGuard{
|
||||
cancel: cancel,
|
||||
timeout: timeout,
|
||||
}
|
||||
guard.timer = clock.AfterFunc(
|
||||
timeout,
|
||||
guard.onTimeout,
|
||||
streamSilenceGuardTimerTag,
|
||||
)
|
||||
return guard
|
||||
}
|
||||
|
||||
func (g *startupGuard) onTimeout() {
|
||||
g.once.Do(func() {
|
||||
g.cancel(errStartupTimeout)
|
||||
})
|
||||
func (g *streamSilenceGuard) settle() bool {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if g.settled {
|
||||
return false
|
||||
}
|
||||
g.settled = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *startupGuard) Disarm() {
|
||||
g.once.Do(func() {
|
||||
g.timer.Stop()
|
||||
})
|
||||
func (g *streamSilenceGuard) onTimeout() {
|
||||
if !g.settle() {
|
||||
return
|
||||
}
|
||||
g.cancel(errStreamSilenceTimeout)
|
||||
}
|
||||
|
||||
func classifyStartupTimeout(
|
||||
func (g *streamSilenceGuard) Reset() {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if g.settled {
|
||||
return
|
||||
}
|
||||
g.timer.Reset(g.timeout, streamSilenceGuardTimerTag)
|
||||
}
|
||||
|
||||
func (g *streamSilenceGuard) Disarm() {
|
||||
if !g.settle() {
|
||||
return
|
||||
}
|
||||
g.timer.Stop()
|
||||
}
|
||||
|
||||
func classifyStreamSilenceTimeout(
|
||||
attemptCtx context.Context,
|
||||
provider string,
|
||||
err error,
|
||||
) error {
|
||||
if !errors.Is(context.Cause(attemptCtx), errStartupTimeout) {
|
||||
if !errors.Is(context.Cause(attemptCtx), errStreamSilenceTimeout) {
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
err = errStartupTimeout
|
||||
err = errStreamSilenceTimeout
|
||||
}
|
||||
return chaterror.WithClassification(err, chaterror.ClassifiedError{
|
||||
Kind: codersdk.ChatErrorKindStartupTimeout,
|
||||
@@ -851,7 +882,7 @@ func guardedStream(
|
||||
metrics *Metrics,
|
||||
) (guardedAttempt, error) {
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(parent)
|
||||
guard := newStartupGuard(clock, timeout, cancelAttempt)
|
||||
guard := newStreamSilenceGuard(clock, timeout, cancelAttempt)
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
@@ -863,7 +894,7 @@ func guardedStream(
|
||||
streamStart := clock.Now()
|
||||
stream, err := openStream(attemptCtx)
|
||||
if err != nil {
|
||||
err = classifyStartupTimeout(attemptCtx, provider, err)
|
||||
err = classifyStreamSilenceTimeout(attemptCtx, provider, err)
|
||||
release()
|
||||
return guardedAttempt{}, err
|
||||
}
|
||||
@@ -877,7 +908,7 @@ func guardedStream(
|
||||
ctx: attemptCtx,
|
||||
stream: fantasy.StreamResponse(func(yield func(fantasy.StreamPart) bool) {
|
||||
for part := range stream {
|
||||
guard.Disarm()
|
||||
guard.Reset()
|
||||
recordTTFT()
|
||||
if !yield(part) {
|
||||
return
|
||||
@@ -886,7 +917,7 @@ func guardedStream(
|
||||
}),
|
||||
release: release,
|
||||
finish: func(err error) error {
|
||||
return classifyStartupTimeout(attemptCtx, provider, err)
|
||||
return classifyStreamSilenceTimeout(attemptCtx, provider, err)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -581,13 +581,13 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
|
||||
func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for range 128 {
|
||||
var cancels atomic.Int32
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) {
|
||||
if errors.Is(err, errStartupTimeout) {
|
||||
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, func(err error) {
|
||||
if errors.Is(err, errStreamSilenceTimeout) {
|
||||
cancels.Add(1)
|
||||
}
|
||||
})
|
||||
@@ -618,17 +618,17 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(context.Background())
|
||||
defer cancelAttempt(nil)
|
||||
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
||||
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
||||
guard.Disarm()
|
||||
guard.onTimeout()
|
||||
|
||||
classified := chaterror.Classify(classifyStartupTimeout(
|
||||
classified := chaterror.Classify(classifyStreamSilenceTimeout(
|
||||
attemptCtx,
|
||||
"openai",
|
||||
xerrors.New("invalid model"),
|
||||
@@ -638,10 +638,10 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
require.Nil(t, context.Cause(attemptCtx))
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
func TestRun_RetriesSilenceTimeoutWhileOpeningStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -650,7 +650,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -675,10 +675,10 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -694,7 +694,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
@@ -710,9 +710,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -728,7 +728,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
t.Run(provider, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -737,7 +737,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -763,10 +763,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -795,10 +795,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
func TestRun_RetriesSilenceTimeoutBeforeFirstPart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -807,7 +807,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -837,10 +837,10 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -856,7 +856,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
@@ -872,16 +872,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
func TestRun_StreamPartsResetSilenceTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -890,12 +890,17 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer armTrap.Close()
|
||||
resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag)
|
||||
defer resetTrap.Close()
|
||||
|
||||
attempts := 0
|
||||
retried := false
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
continueStream := make(chan struct{})
|
||||
secondPartYielded := make(chan struct{}, 1)
|
||||
continueToSecond := make(chan struct{})
|
||||
continueToFinish := make(chan struct{})
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
@@ -910,7 +915,29 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
select {
|
||||
case <-continueStream:
|
||||
case <-continueToSecond:
|
||||
case <-ctx.Done():
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeTextDelta,
|
||||
ID: "text-1",
|
||||
Delta: "done",
|
||||
}) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case secondPartYielded <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-continueToFinish:
|
||||
case <-ctx.Done():
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
@@ -920,7 +947,6 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
@@ -936,10 +962,10 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -954,23 +980,130 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
armTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
select {
|
||||
case <-firstPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
close(continueStream)
|
||||
mClock.Advance(silenceTimeout / 2).MustWait(ctx)
|
||||
close(continueToSecond)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
select {
|
||||
case <-secondPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for second stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(silenceTimeout / 2).MustWait(ctx)
|
||||
close(continueToFinish)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 1, attempts)
|
||||
require.False(t, retried)
|
||||
}
|
||||
|
||||
func TestRun_RetriesSilenceTimeoutBetweenParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitLong,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer armTrap.Close()
|
||||
resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag)
|
||||
defer resetTrap.Close()
|
||||
|
||||
attempts := 0
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case firstPartYielded <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
attemptCause <- context.Cause(ctx)
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
return streamFromParts([]fantasy.StreamPart{{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
}}), nil
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
armTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
select {
|
||||
case <-firstPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
armTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, codersdk.ChatErrorKindStartupTimeout, retries[0].Kind)
|
||||
require.True(t, retries[0].Retryable)
|
||||
require.Equal(t, "openai", retries[0].Provider)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1014,10 +1147,10 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
|
||||
t.Fatal("expected Run to panic")
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
func TestRun_RetriesSilenceTimeoutWhenStreamStaysSilent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -1026,7 +1159,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -1052,10 +1185,10 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -1071,7 +1204,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
@@ -1087,9 +1220,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user