mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
Cherry-pick of https://github.com/coder/coder/pull/25782
Original PR: #25782 — fix(coderd/x/chatd/chatloop): use stream silence
timeout
Merge commit: 7e2f7198dd
Requested by: @ethanndickson
Co-authored-by: Ethan <ethanndickson@gmail.com>
This commit is contained in:
committed by
GitHub
parent
eb8b062b1d
commit
85d39b3dbe
@@ -39,10 +39,11 @@ const (
|
||||
// prevents infinite compaction loops when the model keeps
|
||||
// hitting the context limit after summarization.
|
||||
maxCompactionRetries = 3
|
||||
// defaultStartupTimeout bounds how long an individual
|
||||
// model attempt may spend starting to respond before
|
||||
// defaultStreamSilenceTimeout bounds how long an individual
|
||||
// model attempt may go without receiving a stream part before
|
||||
// the attempt is canceled and retried.
|
||||
defaultStartupTimeout = 60 * time.Second
|
||||
defaultStreamSilenceTimeout = 10 * time.Minute
|
||||
streamSilenceGuardTimerTag = "streamSilenceGuard"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -53,8 +54,8 @@ var (
|
||||
// the run should terminate cleanly after persistence.
|
||||
ErrStopAfterTool = xerrors.New("stop after tool")
|
||||
|
||||
errStartupTimeout = xerrors.New(
|
||||
"chat response did not start before the startup timeout",
|
||||
errStreamSilenceTimeout = xerrors.New(
|
||||
"chat stream was silent for longer than the configured timeout",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -114,14 +115,14 @@ type RunOptions struct {
|
||||
Messages []fantasy.Message
|
||||
Tools []fantasy.AgentTool
|
||||
MaxSteps int
|
||||
// StartupTimeout bounds how long each model attempt may
|
||||
// spend opening the provider stream and waiting for its
|
||||
// first stream part before the attempt is canceled and
|
||||
// retried. Zero uses the production default.
|
||||
StartupTimeout time.Duration
|
||||
// Clock creates startup guard timers. In production use a
|
||||
// real clock; tests can inject quartz.NewMock(t) to make
|
||||
// startup timeout behavior deterministic.
|
||||
// StreamSilenceTimeout bounds how long each model attempt
|
||||
// may go without receiving a stream part before the
|
||||
// attempt is canceled and retried. Zero uses the
|
||||
// production default.
|
||||
StreamSilenceTimeout time.Duration
|
||||
// Clock creates stream silence guard timers. In production
|
||||
// use a real clock; tests can inject quartz.NewMock(t) to
|
||||
// make timeout behavior deterministic.
|
||||
Clock quartz.Clock
|
||||
|
||||
ActiveTools []string
|
||||
@@ -364,8 +365,8 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
if opts.MaxSteps <= 0 {
|
||||
opts.MaxSteps = 1
|
||||
}
|
||||
if opts.StartupTimeout <= 0 {
|
||||
opts.StartupTimeout = defaultStartupTimeout
|
||||
if opts.StreamSilenceTimeout <= 0 {
|
||||
opts.StreamSilenceTimeout = defaultStreamSilenceTimeout
|
||||
}
|
||||
if opts.Clock == nil {
|
||||
opts.Clock = quartz.NewReal()
|
||||
@@ -468,7 +469,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
provider,
|
||||
modelName,
|
||||
opts.Clock,
|
||||
opts.StartupTimeout,
|
||||
opts.StreamSilenceTimeout,
|
||||
func(attemptCtx context.Context) (fantasy.StreamResponse, error) {
|
||||
return opts.Model.Stream(attemptCtx, call)
|
||||
},
|
||||
@@ -782,9 +783,9 @@ func prepareMessagesForRequest(
|
||||
return canonical, prompt, nil
|
||||
}
|
||||
|
||||
// guardedAttempt owns an attempt-scoped context and startup guard
|
||||
// guardedAttempt owns an attempt-scoped context and silence guard
|
||||
// around a provider stream. release is idempotent and frees the
|
||||
// attempt-scoped timer/context. finish canonicalizes startup timeout
|
||||
// attempt-scoped timer/context. finish canonicalizes silence timeout
|
||||
// errors before the retry loop classifies them.
|
||||
type guardedAttempt struct {
|
||||
ctx context.Context
|
||||
@@ -793,47 +794,77 @@ type guardedAttempt struct {
|
||||
finish func(error) error
|
||||
}
|
||||
|
||||
// startupGuard arbitrates whether an attempt times out during
|
||||
// stream startup. Exactly one outcome wins: the timer cancels
|
||||
// the attempt, or the first-part path disarms the timer.
|
||||
type startupGuard struct {
|
||||
timer *quartz.Timer
|
||||
cancel context.CancelCauseFunc
|
||||
once sync.Once
|
||||
// streamSilenceGuard arbitrates whether an attempt times out while
|
||||
// waiting for the next stream part. Exactly one outcome wins: the
|
||||
// timer cancels the attempt, or release disarms the timer.
|
||||
type streamSilenceGuard struct {
|
||||
mu sync.Mutex
|
||||
timer *quartz.Timer
|
||||
cancel context.CancelCauseFunc
|
||||
timeout time.Duration
|
||||
settled bool
|
||||
}
|
||||
|
||||
func newStartupGuard(
|
||||
func newStreamSilenceGuard(
|
||||
clock quartz.Clock,
|
||||
timeout time.Duration,
|
||||
cancel context.CancelCauseFunc,
|
||||
) *startupGuard {
|
||||
guard := &startupGuard{cancel: cancel}
|
||||
guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard")
|
||||
) *streamSilenceGuard {
|
||||
guard := &streamSilenceGuard{
|
||||
cancel: cancel,
|
||||
timeout: timeout,
|
||||
}
|
||||
guard.timer = clock.AfterFunc(
|
||||
timeout,
|
||||
guard.onTimeout,
|
||||
streamSilenceGuardTimerTag,
|
||||
)
|
||||
return guard
|
||||
}
|
||||
|
||||
func (g *startupGuard) onTimeout() {
|
||||
g.once.Do(func() {
|
||||
g.cancel(errStartupTimeout)
|
||||
})
|
||||
func (g *streamSilenceGuard) settle() bool {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if g.settled {
|
||||
return false
|
||||
}
|
||||
g.settled = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *startupGuard) Disarm() {
|
||||
g.once.Do(func() {
|
||||
g.timer.Stop()
|
||||
})
|
||||
func (g *streamSilenceGuard) onTimeout() {
|
||||
if !g.settle() {
|
||||
return
|
||||
}
|
||||
g.cancel(errStreamSilenceTimeout)
|
||||
}
|
||||
|
||||
func classifyStartupTimeout(
|
||||
func (g *streamSilenceGuard) Reset() {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if g.settled {
|
||||
return
|
||||
}
|
||||
g.timer.Reset(g.timeout, streamSilenceGuardTimerTag)
|
||||
}
|
||||
|
||||
func (g *streamSilenceGuard) Disarm() {
|
||||
if !g.settle() {
|
||||
return
|
||||
}
|
||||
g.timer.Stop()
|
||||
}
|
||||
|
||||
func classifyStreamSilenceTimeout(
|
||||
attemptCtx context.Context,
|
||||
provider string,
|
||||
err error,
|
||||
) error {
|
||||
if !errors.Is(context.Cause(attemptCtx), errStartupTimeout) {
|
||||
if !errors.Is(context.Cause(attemptCtx), errStreamSilenceTimeout) {
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
err = errStartupTimeout
|
||||
err = errStreamSilenceTimeout
|
||||
}
|
||||
return chaterror.WithClassification(err, chaterror.ClassifiedError{
|
||||
Kind: codersdk.ChatErrorKindStartupTimeout,
|
||||
@@ -851,7 +882,7 @@ func guardedStream(
|
||||
metrics *Metrics,
|
||||
) (guardedAttempt, error) {
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(parent)
|
||||
guard := newStartupGuard(clock, timeout, cancelAttempt)
|
||||
guard := newStreamSilenceGuard(clock, timeout, cancelAttempt)
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
@@ -863,7 +894,7 @@ func guardedStream(
|
||||
streamStart := clock.Now()
|
||||
stream, err := openStream(attemptCtx)
|
||||
if err != nil {
|
||||
err = classifyStartupTimeout(attemptCtx, provider, err)
|
||||
err = classifyStreamSilenceTimeout(attemptCtx, provider, err)
|
||||
release()
|
||||
return guardedAttempt{}, err
|
||||
}
|
||||
@@ -877,7 +908,7 @@ func guardedStream(
|
||||
ctx: attemptCtx,
|
||||
stream: fantasy.StreamResponse(func(yield func(fantasy.StreamPart) bool) {
|
||||
for part := range stream {
|
||||
guard.Disarm()
|
||||
guard.Reset()
|
||||
recordTTFT()
|
||||
if !yield(part) {
|
||||
return
|
||||
@@ -886,7 +917,7 @@ func guardedStream(
|
||||
}),
|
||||
release: release,
|
||||
finish: func(err error) error {
|
||||
return classifyStartupTimeout(attemptCtx, provider, err)
|
||||
return classifyStreamSilenceTimeout(attemptCtx, provider, err)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -581,13 +581,13 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
|
||||
func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for range 128 {
|
||||
var cancels atomic.Int32
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) {
|
||||
if errors.Is(err, errStartupTimeout) {
|
||||
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, func(err error) {
|
||||
if errors.Is(err, errStreamSilenceTimeout) {
|
||||
cancels.Add(1)
|
||||
}
|
||||
})
|
||||
@@ -618,17 +618,17 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(context.Background())
|
||||
defer cancelAttempt(nil)
|
||||
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
||||
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
||||
guard.Disarm()
|
||||
guard.onTimeout()
|
||||
|
||||
classified := chaterror.Classify(classifyStartupTimeout(
|
||||
classified := chaterror.Classify(classifyStreamSilenceTimeout(
|
||||
attemptCtx,
|
||||
"openai",
|
||||
xerrors.New("invalid model"),
|
||||
@@ -638,10 +638,10 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
require.Nil(t, context.Cause(attemptCtx))
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
func TestRun_RetriesSilenceTimeoutWhileOpeningStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -650,7 +650,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -675,10 +675,10 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -694,7 +694,7 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
@@ -710,9 +710,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -728,7 +728,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
t.Run(provider, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -737,7 +737,7 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -763,10 +763,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -795,10 +795,10 @@ func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
func TestRun_RetriesSilenceTimeoutBeforeFirstPart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -807,7 +807,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -837,10 +837,10 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -856,7 +856,7 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
@@ -872,16 +872,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
func TestRun_StreamPartsResetSilenceTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -890,12 +890,17 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer armTrap.Close()
|
||||
resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag)
|
||||
defer resetTrap.Close()
|
||||
|
||||
attempts := 0
|
||||
retried := false
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
continueStream := make(chan struct{})
|
||||
secondPartYielded := make(chan struct{}, 1)
|
||||
continueToSecond := make(chan struct{})
|
||||
continueToFinish := make(chan struct{})
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
@@ -910,7 +915,29 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
select {
|
||||
case <-continueStream:
|
||||
case <-continueToSecond:
|
||||
case <-ctx.Done():
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeTextDelta,
|
||||
ID: "text-1",
|
||||
Delta: "done",
|
||||
}) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case secondPartYielded <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-continueToFinish:
|
||||
case <-ctx.Done():
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
@@ -920,7 +947,6 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
@@ -936,10 +962,10 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -954,23 +980,130 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
armTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
select {
|
||||
case <-firstPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
close(continueStream)
|
||||
mClock.Advance(silenceTimeout / 2).MustWait(ctx)
|
||||
close(continueToSecond)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
select {
|
||||
case <-secondPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for second stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(silenceTimeout / 2).MustWait(ctx)
|
||||
close(continueToFinish)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 1, attempts)
|
||||
require.False(t, retried)
|
||||
}
|
||||
|
||||
func TestRun_RetriesSilenceTimeoutBetweenParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitLong,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer armTrap.Close()
|
||||
resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag)
|
||||
defer resetTrap.Close()
|
||||
|
||||
attempts := 0
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case firstPartYielded <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
attemptCause <- context.Cause(ctx)
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
return streamFromParts([]fantasy.StreamPart{{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
}}), nil
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
armTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
select {
|
||||
case <-firstPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
armTrap.MustWait(ctx).MustRelease(ctx)
|
||||
resetTrap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, codersdk.ChatErrorKindStartupTimeout, retries[0].Kind)
|
||||
require.True(t, retries[0].Retryable)
|
||||
require.Equal(t, "openai", retries[0].Provider)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1014,10 +1147,10 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
|
||||
t.Fatal("expected Run to panic")
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
func TestRun_RetriesSilenceTimeoutWhenStreamStaysSilent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
const silenceTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
@@ -1026,7 +1159,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
@@ -1052,10 +1185,10 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StreamSilenceTimeout: silenceTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
@@ -1071,7 +1204,7 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
mClock.Advance(silenceTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
@@ -1087,9 +1220,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
require.ErrorIs(t, cause, errStreamSilenceTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
t.Fatal("timed out waiting for silence timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user