mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
2bdacae5f5
## Summary Adds automatic retry with exponential backoff for transient LLM errors during chat streaming and title generation. Inspired by [coder/mux](https://github.com/coder/mux)'s retry mechanism. ## Key Behaviors - **Infinite retries** with exponential backoff: 1s → 2s → 4s → ... → 60s cap - **Deterministic delays** (no jitter) - **Error classification**: retryable (429, 5xx, overloaded, rate limit, network errors) vs non-retryable (auth, quota, context exceeded, model not found, canceled) - **Retry status published to SSE stream** so frontend can show "Retrying in Xs..." UI - **Title generation** retries silently (best-effort, nil onRetry callback) ## New Package: `coderd/chatd/chatretry/` | File | Purpose | |------|---------| | `classify.go` | `IsRetryable(err)` and `StatusCodeRetryable(code)` | | `backoff.go` | `Delay(attempt)` — exponential doubling with 60s cap | | `retry.go` | `Retry(ctx, fn, onRetry)` — infinite loop with context-aware timer | ## Test Helpers: `coderd/chatd/chattest/errors.go` Anthropic and OpenAI error response builders for use in chattest providers: - `AnthropicErrorResponse()`, `AnthropicOverloadedResponse()`, `AnthropicRateLimitResponse()` - `OpenAIErrorResponse()`, `OpenAIRateLimitResponse()`, `OpenAIServerErrorResponse()` ## SDK Changes: `codersdk/chats.go` - New `ChatStreamEventType: "retry"` - New `ChatStreamRetry` struct with `Attempt`, `DelayMs`, `Error`, `RetryingAt` fields - TypeScript types auto-generated ## Changed Files - `coderd/chatd/chatloop/chatloop.go` — wraps `agent.Stream()` in `chatretry.Retry()` - `coderd/chatd/chatd.go` — publishes retry events to SSE stream with logging - `coderd/chatd/title.go` — wraps `model.Generate()` in silent retry - `coderd/chatd/chattest/anthropic.go` / `openai.go` — error injection support ## Tests 42 tests covering classification (33), backoff (9), and retry scenarios (8).
453 lines
9.8 KiB
Go
453 lines
9.8 KiB
Go
package chatretry_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
|
)
|
|
|
|
func TestIsRetryable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
retryable bool
|
|
}{
|
|
// Retryable errors.
|
|
{
|
|
name: "Overloaded",
|
|
err: xerrors.New("model is overloaded, please try again"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "RateLimit",
|
|
err: xerrors.New("rate limit exceeded"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "RateLimitUnderscore",
|
|
err: xerrors.New("rate_limit: too many requests"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "TooManyRequests",
|
|
err: xerrors.New("too many requests"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "HTTP429InMessage",
|
|
err: xerrors.New("received status 429 from upstream"),
|
|
retryable: false, // "429" alone is not a pattern; needs matching text.
|
|
},
|
|
{
|
|
name: "HTTP529InMessage",
|
|
err: xerrors.New("received status 529 from upstream"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "ServerError500",
|
|
err: xerrors.New("status 500: internal server error"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "ServerErrorGeneric",
|
|
err: xerrors.New("server error"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "ConnectionReset",
|
|
err: xerrors.New("read tcp: connection reset by peer"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "ConnectionRefused",
|
|
err: xerrors.New("dial tcp: connection refused"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "EOF",
|
|
err: xerrors.New("unexpected EOF"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "BrokenPipe",
|
|
err: xerrors.New("write: broken pipe"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "NetworkTimeout",
|
|
err: xerrors.New("i/o timeout"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "ServiceUnavailable",
|
|
err: xerrors.New("service unavailable"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "Unavailable",
|
|
err: xerrors.New("the service is currently unavailable"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "Status502",
|
|
err: xerrors.New("status 502: bad gateway"),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "Status503",
|
|
err: xerrors.New("status 503"),
|
|
retryable: true,
|
|
},
|
|
|
|
// Non-retryable errors.
|
|
{
|
|
name: "Nil",
|
|
err: nil,
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ContextCanceled",
|
|
err: context.Canceled,
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ContextCanceledWrapped",
|
|
err: xerrors.Errorf("operation failed: %w", context.Canceled),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ContextCanceledMessage",
|
|
err: xerrors.New("context canceled"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ContextDeadlineExceeded",
|
|
err: xerrors.New("context deadline exceeded"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "Authentication",
|
|
err: xerrors.New("authentication failed"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "Unauthorized",
|
|
err: xerrors.New("401 Unauthorized"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "Forbidden",
|
|
err: xerrors.New("403 Forbidden"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "InvalidAPIKey",
|
|
err: xerrors.New("invalid api key"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "InvalidAPIKeyUnderscore",
|
|
err: xerrors.New("invalid_api_key"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "InvalidModel",
|
|
err: xerrors.New("invalid model: gpt-5-turbo"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ModelNotFound",
|
|
err: xerrors.New("model not found"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ModelNotFoundUnderscore",
|
|
err: xerrors.New("model_not_found"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ContextLengthExceeded",
|
|
err: xerrors.New("context length exceeded"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "ContextExceededUnderscore",
|
|
err: xerrors.New("context_exceeded"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "MaximumContextLength",
|
|
err: xerrors.New("maximum context length"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "QuotaExceeded",
|
|
err: xerrors.New("quota exceeded"),
|
|
retryable: false,
|
|
},
|
|
{
|
|
name: "BillingError",
|
|
err: xerrors.New("billing issue: payment required"),
|
|
retryable: false,
|
|
},
|
|
|
|
// Wrapped errors preserve retryability.
|
|
{
|
|
name: "WrappedRetryable",
|
|
err: xerrors.Errorf("provider call failed: %w", xerrors.New("service unavailable")),
|
|
retryable: true,
|
|
},
|
|
{
|
|
name: "WrappedNonRetryable",
|
|
err: xerrors.Errorf("provider call failed: %w", xerrors.New("invalid api key")),
|
|
retryable: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
got := chatretry.IsRetryable(tt.err)
|
|
if got != tt.retryable {
|
|
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.retryable)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatusCodeRetryable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
code int
|
|
retryable bool
|
|
}{
|
|
{429, true},
|
|
{500, true},
|
|
{502, true},
|
|
{503, true},
|
|
{529, true},
|
|
{200, false},
|
|
{400, false},
|
|
{401, false},
|
|
{403, false},
|
|
{404, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(fmt.Sprintf("Status%d", tt.code), func(t *testing.T) {
|
|
t.Parallel()
|
|
got := chatretry.StatusCodeRetryable(tt.code)
|
|
if got != tt.retryable {
|
|
t.Errorf("StatusCodeRetryable(%d) = %v, want %v", tt.code, got, tt.retryable)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDelay(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
attempt int
|
|
want time.Duration
|
|
}{
|
|
{0, 1 * time.Second},
|
|
{1, 2 * time.Second},
|
|
{2, 4 * time.Second},
|
|
{3, 8 * time.Second},
|
|
{4, 16 * time.Second},
|
|
{5, 32 * time.Second},
|
|
{6, 60 * time.Second}, // Capped at MaxDelay.
|
|
{10, 60 * time.Second}, // Still capped.
|
|
{100, 60 * time.Second},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(fmt.Sprintf("Attempt%d", tt.attempt), func(t *testing.T) {
|
|
t.Parallel()
|
|
got := chatretry.Delay(tt.attempt)
|
|
if got != tt.want {
|
|
t.Errorf("Delay(%d) = %v, want %v", tt.attempt, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRetry_SuccessOnFirstTry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
calls := 0
|
|
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
|
calls++
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("expected fn called once, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetry_TransientThenSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
calls := 0
|
|
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
|
calls++
|
|
if calls == 1 {
|
|
return xerrors.New("service unavailable")
|
|
}
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
if calls != 2 {
|
|
t.Fatalf("expected fn called twice, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetry_MultipleTransientThenSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
calls := 0
|
|
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
|
calls++
|
|
if calls <= 3 {
|
|
return xerrors.New("overloaded")
|
|
}
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
if calls != 4 {
|
|
t.Fatalf("expected fn called 4 times, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetry_NonRetryableError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
calls := 0
|
|
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
|
calls++
|
|
return xerrors.New("invalid api key")
|
|
}, nil)
|
|
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if err.Error() != "invalid api key" {
|
|
t.Fatalf("expected 'invalid api key', got %q", err.Error())
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("expected fn called once, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestRetry_ContextCanceledDuringWait(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
calls := 0
|
|
err := chatretry.Retry(ctx, func(_ context.Context) error {
|
|
calls++
|
|
// Cancel after the first retryable error so the wait
|
|
// select picks up the cancellation.
|
|
if calls == 1 {
|
|
cancel()
|
|
}
|
|
return xerrors.New("overloaded")
|
|
}, nil)
|
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("expected context.Canceled, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRetry_ContextCanceledDuringFn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
err := chatretry.Retry(ctx, func(_ context.Context) error {
|
|
cancel()
|
|
// Return a retryable error; the loop should detect that
|
|
// ctx is done and return the context error.
|
|
return xerrors.New("overloaded")
|
|
}, nil)
|
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("expected context.Canceled, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRetry_OnRetryCalledWithCorrectArgs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type retryRecord struct {
|
|
attempt int
|
|
errMsg string
|
|
delay time.Duration
|
|
}
|
|
var records []retryRecord
|
|
|
|
calls := 0
|
|
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
|
calls++
|
|
if calls <= 2 {
|
|
return xerrors.New("rate limit exceeded")
|
|
}
|
|
return nil
|
|
}, func(attempt int, err error, delay time.Duration) {
|
|
records = append(records, retryRecord{
|
|
attempt: attempt,
|
|
errMsg: err.Error(),
|
|
delay: delay,
|
|
})
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
if len(records) != 2 {
|
|
t.Fatalf("expected 2 onRetry calls, got %d", len(records))
|
|
}
|
|
if records[0].attempt != 1 {
|
|
t.Errorf("first onRetry attempt = %d, want 1", records[0].attempt)
|
|
}
|
|
if records[1].attempt != 2 {
|
|
t.Errorf("second onRetry attempt = %d, want 2", records[1].attempt)
|
|
}
|
|
if records[0].errMsg != "rate limit exceeded" {
|
|
t.Errorf("first onRetry error = %q, want 'rate limit exceeded'", records[0].errMsg)
|
|
}
|
|
}
|
|
|
|
func TestRetry_OnRetryNilDoesNotPanic(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var calls atomic.Int32
|
|
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
|
if calls.Add(1) == 1 {
|
|
return xerrors.New("overloaded")
|
|
}
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("expected nil error, got %v", err)
|
|
}
|
|
}
|