Files
coder/coderd/chatd/chatretry/chatretry_test.go
T
Kyle Carberry 2bdacae5f5 feat(chatd): add LLM stream retry with exponential backoff (#22418)
## Summary

Adds automatic retry with exponential backoff for transient LLM errors
during chat streaming and title generation. Inspired by
[coder/mux](https://github.com/coder/mux)'s retry mechanism.

## Key Behaviors

- **Infinite retries** with exponential backoff: 1s → 2s → 4s → ... →
60s cap
- **Deterministic delays** (no jitter)
- **Error classification**: retryable (429, 5xx, overloaded, rate limit,
network errors) vs non-retryable (auth, quota, context exceeded, model
not found, canceled)
- **Retry status published to SSE stream** so frontend can show
"Retrying in Xs..." UI
- **Title generation** retries silently (best-effort, nil onRetry
callback)

## New Package: `coderd/chatd/chatretry/`

| File | Purpose |
|------|---------|
| `classify.go` | `IsRetryable(err)` and `StatusCodeRetryable(code)` |
| `backoff.go` | `Delay(attempt)` — exponential doubling with 60s cap |
| `retry.go` | `Retry(ctx, fn, onRetry)` — infinite loop with
context-aware timer |

## Test Helpers: `coderd/chatd/chattest/errors.go`

Anthropic and OpenAI error response builders for use in chattest
providers:
- `AnthropicErrorResponse()`, `AnthropicOverloadedResponse()`,
`AnthropicRateLimitResponse()`
- `OpenAIErrorResponse()`, `OpenAIRateLimitResponse()`,
`OpenAIServerErrorResponse()`

## SDK Changes: `codersdk/chats.go`

- New `ChatStreamEventType: "retry"`
- New `ChatStreamRetry` struct with `Attempt`, `DelayMs`, `Error`,
`RetryingAt` fields
- TypeScript types auto-generated

## Changed Files

- `coderd/chatd/chatloop/chatloop.go` — wraps `agent.Stream()` in
`chatretry.Retry()`
- `coderd/chatd/chatd.go` — publishes retry events to SSE stream with
logging
- `coderd/chatd/title.go` — wraps `model.Generate()` in silent retry
- `coderd/chatd/chattest/anthropic.go` / `openai.go` — error injection
support

## Tests

42 tests covering classification (33), backoff (9), and retry scenarios
(8).
2026-02-27 18:34:33 -05:00

453 lines
9.8 KiB
Go

package chatretry_test
import (
"context"
"errors"
"fmt"
"sync/atomic"
"testing"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatretry"
)
func TestIsRetryable(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
retryable bool
}{
// Retryable errors.
{
name: "Overloaded",
err: xerrors.New("model is overloaded, please try again"),
retryable: true,
},
{
name: "RateLimit",
err: xerrors.New("rate limit exceeded"),
retryable: true,
},
{
name: "RateLimitUnderscore",
err: xerrors.New("rate_limit: too many requests"),
retryable: true,
},
{
name: "TooManyRequests",
err: xerrors.New("too many requests"),
retryable: true,
},
{
name: "HTTP429InMessage",
err: xerrors.New("received status 429 from upstream"),
retryable: false, // "429" alone is not a pattern; needs matching text.
},
{
name: "HTTP529InMessage",
err: xerrors.New("received status 529 from upstream"),
retryable: true,
},
{
name: "ServerError500",
err: xerrors.New("status 500: internal server error"),
retryable: true,
},
{
name: "ServerErrorGeneric",
err: xerrors.New("server error"),
retryable: true,
},
{
name: "ConnectionReset",
err: xerrors.New("read tcp: connection reset by peer"),
retryable: true,
},
{
name: "ConnectionRefused",
err: xerrors.New("dial tcp: connection refused"),
retryable: true,
},
{
name: "EOF",
err: xerrors.New("unexpected EOF"),
retryable: true,
},
{
name: "BrokenPipe",
err: xerrors.New("write: broken pipe"),
retryable: true,
},
{
name: "NetworkTimeout",
err: xerrors.New("i/o timeout"),
retryable: true,
},
{
name: "ServiceUnavailable",
err: xerrors.New("service unavailable"),
retryable: true,
},
{
name: "Unavailable",
err: xerrors.New("the service is currently unavailable"),
retryable: true,
},
{
name: "Status502",
err: xerrors.New("status 502: bad gateway"),
retryable: true,
},
{
name: "Status503",
err: xerrors.New("status 503"),
retryable: true,
},
// Non-retryable errors.
{
name: "Nil",
err: nil,
retryable: false,
},
{
name: "ContextCanceled",
err: context.Canceled,
retryable: false,
},
{
name: "ContextCanceledWrapped",
err: xerrors.Errorf("operation failed: %w", context.Canceled),
retryable: false,
},
{
name: "ContextCanceledMessage",
err: xerrors.New("context canceled"),
retryable: false,
},
{
name: "ContextDeadlineExceeded",
err: xerrors.New("context deadline exceeded"),
retryable: false,
},
{
name: "Authentication",
err: xerrors.New("authentication failed"),
retryable: false,
},
{
name: "Unauthorized",
err: xerrors.New("401 Unauthorized"),
retryable: false,
},
{
name: "Forbidden",
err: xerrors.New("403 Forbidden"),
retryable: false,
},
{
name: "InvalidAPIKey",
err: xerrors.New("invalid api key"),
retryable: false,
},
{
name: "InvalidAPIKeyUnderscore",
err: xerrors.New("invalid_api_key"),
retryable: false,
},
{
name: "InvalidModel",
err: xerrors.New("invalid model: gpt-5-turbo"),
retryable: false,
},
{
name: "ModelNotFound",
err: xerrors.New("model not found"),
retryable: false,
},
{
name: "ModelNotFoundUnderscore",
err: xerrors.New("model_not_found"),
retryable: false,
},
{
name: "ContextLengthExceeded",
err: xerrors.New("context length exceeded"),
retryable: false,
},
{
name: "ContextExceededUnderscore",
err: xerrors.New("context_exceeded"),
retryable: false,
},
{
name: "MaximumContextLength",
err: xerrors.New("maximum context length"),
retryable: false,
},
{
name: "QuotaExceeded",
err: xerrors.New("quota exceeded"),
retryable: false,
},
{
name: "BillingError",
err: xerrors.New("billing issue: payment required"),
retryable: false,
},
// Wrapped errors preserve retryability.
{
name: "WrappedRetryable",
err: xerrors.Errorf("provider call failed: %w", xerrors.New("service unavailable")),
retryable: true,
},
{
name: "WrappedNonRetryable",
err: xerrors.Errorf("provider call failed: %w", xerrors.New("invalid api key")),
retryable: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatretry.IsRetryable(tt.err)
if got != tt.retryable {
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.retryable)
}
})
}
}
func TestStatusCodeRetryable(t *testing.T) {
t.Parallel()
tests := []struct {
code int
retryable bool
}{
{429, true},
{500, true},
{502, true},
{503, true},
{529, true},
{200, false},
{400, false},
{401, false},
{403, false},
{404, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Status%d", tt.code), func(t *testing.T) {
t.Parallel()
got := chatretry.StatusCodeRetryable(tt.code)
if got != tt.retryable {
t.Errorf("StatusCodeRetryable(%d) = %v, want %v", tt.code, got, tt.retryable)
}
})
}
}
func TestDelay(t *testing.T) {
t.Parallel()
tests := []struct {
attempt int
want time.Duration
}{
{0, 1 * time.Second},
{1, 2 * time.Second},
{2, 4 * time.Second},
{3, 8 * time.Second},
{4, 16 * time.Second},
{5, 32 * time.Second},
{6, 60 * time.Second}, // Capped at MaxDelay.
{10, 60 * time.Second}, // Still capped.
{100, 60 * time.Second},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Attempt%d", tt.attempt), func(t *testing.T) {
t.Parallel()
got := chatretry.Delay(tt.attempt)
if got != tt.want {
t.Errorf("Delay(%d) = %v, want %v", tt.attempt, got, tt.want)
}
})
}
}
func TestRetry_SuccessOnFirstTry(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if calls != 1 {
t.Fatalf("expected fn called once, got %d", calls)
}
}
func TestRetry_TransientThenSuccess(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
if calls == 1 {
return xerrors.New("service unavailable")
}
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if calls != 2 {
t.Fatalf("expected fn called twice, got %d", calls)
}
}
func TestRetry_MultipleTransientThenSuccess(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
if calls <= 3 {
return xerrors.New("overloaded")
}
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if calls != 4 {
t.Fatalf("expected fn called 4 times, got %d", calls)
}
}
func TestRetry_NonRetryableError(t *testing.T) {
t.Parallel()
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
return xerrors.New("invalid api key")
}, nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if err.Error() != "invalid api key" {
t.Fatalf("expected 'invalid api key', got %q", err.Error())
}
if calls != 1 {
t.Fatalf("expected fn called once, got %d", calls)
}
}
func TestRetry_ContextCanceledDuringWait(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
calls := 0
err := chatretry.Retry(ctx, func(_ context.Context) error {
calls++
// Cancel after the first retryable error so the wait
// select picks up the cancellation.
if calls == 1 {
cancel()
}
return xerrors.New("overloaded")
}, nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestRetry_ContextCanceledDuringFn(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
err := chatretry.Retry(ctx, func(_ context.Context) error {
cancel()
// Return a retryable error; the loop should detect that
// ctx is done and return the context error.
return xerrors.New("overloaded")
}, nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestRetry_OnRetryCalledWithCorrectArgs(t *testing.T) {
t.Parallel()
type retryRecord struct {
attempt int
errMsg string
delay time.Duration
}
var records []retryRecord
calls := 0
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
calls++
if calls <= 2 {
return xerrors.New("rate limit exceeded")
}
return nil
}, func(attempt int, err error, delay time.Duration) {
records = append(records, retryRecord{
attempt: attempt,
errMsg: err.Error(),
delay: delay,
})
})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if len(records) != 2 {
t.Fatalf("expected 2 onRetry calls, got %d", len(records))
}
if records[0].attempt != 1 {
t.Errorf("first onRetry attempt = %d, want 1", records[0].attempt)
}
if records[1].attempt != 2 {
t.Errorf("second onRetry attempt = %d, want 2", records[1].attempt)
}
if records[0].errMsg != "rate limit exceeded" {
t.Errorf("first onRetry error = %q, want 'rate limit exceeded'", records[0].errMsg)
}
}
func TestRetry_OnRetryNilDoesNotPanic(t *testing.T) {
t.Parallel()
var calls atomic.Int32
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
if calls.Add(1) == 1 {
return xerrors.New("overloaded")
}
return nil
}, nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
}