mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(chatd): add LLM stream retry with exponential backoff (#22418)
## Summary

Adds automatic retry with exponential backoff for transient LLM errors during chat streaming and title generation. Inspired by [coder/mux](https://github.com/coder/mux)'s retry mechanism.

## Key Behaviors

- **Infinite retries** with exponential backoff: 1s → 2s → 4s → ... → 60s cap
- **Deterministic delays** (no jitter)
- **Error classification**: retryable (429, 5xx, overloaded, rate limit, network errors) vs non-retryable (auth, quota, context exceeded, model not found, canceled)
- **Retry status published to SSE stream** so frontend can show "Retrying in Xs..." UI
- **Title generation** retries silently (best-effort, nil onRetry callback)

## New Package: `coderd/chatd/chatretry/`

| File | Purpose |
|------|---------|
| `classify.go` | `IsRetryable(err)` and `StatusCodeRetryable(code)` |
| `backoff.go` | `Delay(attempt)` — exponential doubling with 60s cap |
| `retry.go` | `Retry(ctx, fn, onRetry)` — infinite loop with context-aware timer |

## Test Helpers: `coderd/chatd/chattest/errors.go`

Anthropic and OpenAI error response builders for use in chattest providers:

- `AnthropicErrorResponse()`, `AnthropicOverloadedResponse()`, `AnthropicRateLimitResponse()`
- `OpenAIErrorResponse()`, `OpenAIRateLimitResponse()`, `OpenAIServerErrorResponse()`

## SDK Changes: `codersdk/chats.go`

- New `ChatStreamEventType: "retry"`
- New `ChatStreamRetry` struct with `Attempt`, `DelayMs`, `Error`, `RetryingAt` fields
- TypeScript types auto-generated

## Changed Files

- `coderd/chatd/chatloop/chatloop.go` — wraps `agent.Stream()` in `chatretry.Retry()`
- `coderd/chatd/chatd.go` — publishes retry events to SSE stream with logging
- `coderd/chatd/title.go` — wraps `model.Generate()` in silent retry
- `coderd/chatd/chattest/anthropic.go` / `openai.go` — error injection support

## Tests

42 tests covering classification (33), backoff (9), and retry scenarios (8).
This commit is contained in:
@@ -2086,6 +2086,24 @@ func (p *Server) runChat(
|
||||
},
|
||||
Compaction: compactionOptions,
|
||||
|
||||
OnRetry: func(attempt int, retryErr error, delay time.Duration) {
|
||||
logger.Warn(ctx, "retrying LLM stream",
|
||||
slog.F("attempt", attempt),
|
||||
slog.F("delay", delay.String()),
|
||||
slog.Error(retryErr),
|
||||
)
|
||||
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeRetry,
|
||||
ChatID: chat.ID,
|
||||
Retry: &codersdk.ChatStreamRetry{
|
||||
Attempt: attempt,
|
||||
DelayMs: delay.Milliseconds(),
|
||||
Error: retryErr.Error(),
|
||||
RetryingAt: time.Now().Add(delay),
|
||||
},
|
||||
})
|
||||
},
|
||||
|
||||
OnInterruptedPersistError: func(err error) {
|
||||
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
|
||||
},
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -54,6 +56,12 @@ type RunOptions struct {
|
||||
)
|
||||
Compaction *CompactionOptions
|
||||
|
||||
// OnRetry is called before each retry attempt when the LLM
|
||||
// stream fails with a retryable error. It provides the attempt
|
||||
// number, error, and backoff delay so callers can publish status
|
||||
// events to connected clients.
|
||||
OnRetry chatretry.OnRetryFn
|
||||
|
||||
OnInterruptedPersistError func(error)
|
||||
}
|
||||
|
||||
@@ -443,15 +451,39 @@ func Run(ctx context.Context, opts RunOptions) (*fantasy.AgentResult, error) {
|
||||
})
|
||||
}
|
||||
|
||||
result, err := agent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
if persistErr := persistInterruptedStep(); persistErr != nil {
|
||||
if opts.OnInterruptedPersistError != nil {
|
||||
opts.OnInterruptedPersistError(persistErr)
|
||||
var result *fantasy.AgentResult
|
||||
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
|
||||
var streamErr error
|
||||
result, streamErr = agent.Stream(retryCtx, streamCall)
|
||||
if streamErr != nil {
|
||||
// Interrupts are not retryable — propagate them
|
||||
// immediately so processChat can set the correct
|
||||
// status.
|
||||
if errors.Is(streamErr, context.Canceled) &&
|
||||
errors.Is(context.Cause(retryCtx), ErrInterrupted) {
|
||||
if persistErr := persistInterruptedStep(); persistErr != nil {
|
||||
if opts.OnInterruptedPersistError != nil {
|
||||
opts.OnInterruptedPersistError(persistErr)
|
||||
}
|
||||
}
|
||||
// Return ErrInterrupted directly so the retry
|
||||
// loop sees a non-retryable error and stops.
|
||||
return ErrInterrupted
|
||||
}
|
||||
return streamErr
|
||||
}
|
||||
return nil
|
||||
}, func(attempt int, retryErr error, delay time.Duration) {
|
||||
// Reset accumulated draft state from the failed attempt
|
||||
// so the next attempt starts clean.
|
||||
resetStepState()
|
||||
|
||||
if opts.OnRetry != nil {
|
||||
opts.OnRetry(attempt, retryErr, delay)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
return nil, ErrInterrupted
|
||||
}
|
||||
return nil, xerrors.Errorf("stream response: %w", err)
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
// Package chatretry provides retry logic for transient LLM provider
|
||||
// errors. It classifies errors as retryable or permanent and
|
||||
// implements exponential backoff matching the behavior of coder/mux.
|
||||
package chatretry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bounds of the exponential backoff schedule.
const (
	// InitialDelay is the wait applied before the first retry.
	InitialDelay = time.Second

	// MaxDelay is the ceiling the backoff never exceeds. The original
	// author notes that it matches the cap used in coder/mux.
	MaxDelay = time.Minute
)
|
||||
|
||||
// nonRetryablePatterns lists lowercase substrings that mark an error
// as permanent. IsRetryable consults this list before
// retryablePatterns, so a message that matches both lists (for
// example "quota exceeded: rate limit") is classified as
// non-retryable.
var nonRetryablePatterns = []string{
	"context canceled",
	"context deadline exceeded",
	"authentication",
	"unauthorized",
	"forbidden",
	"invalid api key",
	"invalid_api_key",
	"invalid model",
	"model not found",
	"model_not_found",
	"context length exceeded",
	"context_exceeded",
	"maximum context length",
	"quota",
	"billing",
}

// retryablePatterns lists lowercase substrings that mark an error as
// transient and therefore worth another attempt.
var retryablePatterns = []string{
	"overloaded",
	"rate limit",
	"rate_limit",
	"too many requests",
	"server error",
	"status 500",
	"status 502",
	"status 503",
	"status 529",
	"connection reset",
	"connection refused",
	"eof",
	"broken pipe",
	"timeout",
	"unavailable",
	"service unavailable",
}

// IsRetryable reports whether err looks like a transient LLM provider
// failure that is worth retrying.
//
// A nil error and anything that wraps context.Canceled are never
// retryable. Every other error is classified purely by its lowercased
// message text: a match in nonRetryablePatterns wins, then a match in
// retryablePatterns, and an unrecognized message is treated as
// permanent. HTTP status codes are only recognized when they appear in
// the message in one of the listed "status NNN" forms; note that
// "status 429" is not among them, so rate limits are detected through
// their descriptive text instead.
func IsRetryable(err error) bool {
	if err == nil || errors.Is(err, context.Canceled) {
		return false
	}

	msg := strings.ToLower(err.Error())
	if matchesAny(msg, nonRetryablePatterns) {
		// Permanent markers take precedence over transient ones.
		return false
	}
	return matchesAny(msg, retryablePatterns)
}

// matchesAny reports whether s contains at least one of patterns.
func matchesAny(s string, patterns []string) bool {
	for i := range patterns {
		if strings.Contains(s, patterns[i]) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// StatusCodeRetryable reports whether an HTTP status code signals a
// transient failure worth retrying. Only 429, 500, 502, 503 and 529
// qualify; every other code yields false.
func StatusCodeRetryable(code int) bool {
	return code == 429 ||
		code == 500 ||
		code == 502 ||
		code == 503 ||
		code == 529
}
|
||||
|
||||
// Delay returns the backoff duration for the given 0-indexed attempt.
|
||||
// Uses exponential backoff: min(InitialDelay * 2^attempt, MaxDelay).
|
||||
// Matches the backoff curve used in coder/mux.
|
||||
func Delay(attempt int) time.Duration {
|
||||
d := InitialDelay
|
||||
for range attempt {
|
||||
d *= 2
|
||||
if d >= MaxDelay {
|
||||
return MaxDelay
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// RetryFn is the function to retry. It receives a context and returns
|
||||
// an error. The context may be a child of the original with adjusted
|
||||
// deadlines for individual attempts.
|
||||
type RetryFn func(ctx context.Context) error
|
||||
|
||||
// OnRetryFn is called before each retry attempt with the attempt
|
||||
// number (1-indexed), the error that triggered the retry, and the
|
||||
// delay before the next attempt.
|
||||
type OnRetryFn func(attempt int, err error, delay time.Duration)
|
||||
|
||||
// Retry calls fn repeatedly until it succeeds, returns a
|
||||
// non-retryable error, or ctx is canceled. There is no max attempt
|
||||
// limit — retries continue indefinitely with exponential backoff
|
||||
// (capped at 60s), matching the behavior of coder/mux.
|
||||
//
|
||||
// The onRetry callback (if non-nil) is called before each retry
|
||||
// attempt, giving the caller a chance to reset state, log, or
|
||||
// publish status events.
|
||||
func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
var attempt int
|
||||
for {
|
||||
err := fn(ctx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !IsRetryable(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the caller's context is already done, return the
|
||||
// context error so cancellation propagates cleanly.
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
delay := Delay(attempt)
|
||||
|
||||
if onRetry != nil {
|
||||
onRetry(attempt+1, err, delay)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,452 @@
|
||||
package chatretry_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
||||
)
|
||||
|
||||
func TestIsRetryable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
retryable bool
|
||||
}{
|
||||
// Retryable errors.
|
||||
{
|
||||
name: "Overloaded",
|
||||
err: xerrors.New("model is overloaded, please try again"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "RateLimit",
|
||||
err: xerrors.New("rate limit exceeded"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "RateLimitUnderscore",
|
||||
err: xerrors.New("rate_limit: too many requests"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "TooManyRequests",
|
||||
err: xerrors.New("too many requests"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP429InMessage",
|
||||
err: xerrors.New("received status 429 from upstream"),
|
||||
retryable: false, // "429" alone is not a pattern; needs matching text.
|
||||
},
|
||||
{
|
||||
name: "HTTP529InMessage",
|
||||
err: xerrors.New("received status 529 from upstream"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "ServerError500",
|
||||
err: xerrors.New("status 500: internal server error"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "ServerErrorGeneric",
|
||||
err: xerrors.New("server error"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "ConnectionReset",
|
||||
err: xerrors.New("read tcp: connection reset by peer"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "ConnectionRefused",
|
||||
err: xerrors.New("dial tcp: connection refused"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "EOF",
|
||||
err: xerrors.New("unexpected EOF"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "BrokenPipe",
|
||||
err: xerrors.New("write: broken pipe"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "NetworkTimeout",
|
||||
err: xerrors.New("i/o timeout"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "ServiceUnavailable",
|
||||
err: xerrors.New("service unavailable"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "Unavailable",
|
||||
err: xerrors.New("the service is currently unavailable"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "Status502",
|
||||
err: xerrors.New("status 502: bad gateway"),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "Status503",
|
||||
err: xerrors.New("status 503"),
|
||||
retryable: true,
|
||||
},
|
||||
|
||||
// Non-retryable errors.
|
||||
{
|
||||
name: "Nil",
|
||||
err: nil,
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ContextCanceled",
|
||||
err: context.Canceled,
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ContextCanceledWrapped",
|
||||
err: xerrors.Errorf("operation failed: %w", context.Canceled),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ContextCanceledMessage",
|
||||
err: xerrors.New("context canceled"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ContextDeadlineExceeded",
|
||||
err: xerrors.New("context deadline exceeded"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "Authentication",
|
||||
err: xerrors.New("authentication failed"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "Unauthorized",
|
||||
err: xerrors.New("401 Unauthorized"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "Forbidden",
|
||||
err: xerrors.New("403 Forbidden"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidAPIKey",
|
||||
err: xerrors.New("invalid api key"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidAPIKeyUnderscore",
|
||||
err: xerrors.New("invalid_api_key"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidModel",
|
||||
err: xerrors.New("invalid model: gpt-5-turbo"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ModelNotFound",
|
||||
err: xerrors.New("model not found"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ModelNotFoundUnderscore",
|
||||
err: xerrors.New("model_not_found"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ContextLengthExceeded",
|
||||
err: xerrors.New("context length exceeded"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "ContextExceededUnderscore",
|
||||
err: xerrors.New("context_exceeded"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "MaximumContextLength",
|
||||
err: xerrors.New("maximum context length"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "QuotaExceeded",
|
||||
err: xerrors.New("quota exceeded"),
|
||||
retryable: false,
|
||||
},
|
||||
{
|
||||
name: "BillingError",
|
||||
err: xerrors.New("billing issue: payment required"),
|
||||
retryable: false,
|
||||
},
|
||||
|
||||
// Wrapped errors preserve retryability.
|
||||
{
|
||||
name: "WrappedRetryable",
|
||||
err: xerrors.Errorf("provider call failed: %w", xerrors.New("service unavailable")),
|
||||
retryable: true,
|
||||
},
|
||||
{
|
||||
name: "WrappedNonRetryable",
|
||||
err: xerrors.Errorf("provider call failed: %w", xerrors.New("invalid api key")),
|
||||
retryable: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatretry.IsRetryable(tt.err)
|
||||
if got != tt.retryable {
|
||||
t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.retryable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusCodeRetryable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
code int
|
||||
retryable bool
|
||||
}{
|
||||
{429, true},
|
||||
{500, true},
|
||||
{502, true},
|
||||
{503, true},
|
||||
{529, true},
|
||||
{200, false},
|
||||
{400, false},
|
||||
{401, false},
|
||||
{403, false},
|
||||
{404, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("Status%d", tt.code), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatretry.StatusCodeRetryable(tt.code)
|
||||
if got != tt.retryable {
|
||||
t.Errorf("StatusCodeRetryable(%d) = %v, want %v", tt.code, got, tt.retryable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
attempt int
|
||||
want time.Duration
|
||||
}{
|
||||
{0, 1 * time.Second},
|
||||
{1, 2 * time.Second},
|
||||
{2, 4 * time.Second},
|
||||
{3, 8 * time.Second},
|
||||
{4, 16 * time.Second},
|
||||
{5, 32 * time.Second},
|
||||
{6, 60 * time.Second}, // Capped at MaxDelay.
|
||||
{10, 60 * time.Second}, // Still capped.
|
||||
{100, 60 * time.Second},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("Attempt%d", tt.attempt), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatretry.Delay(tt.attempt)
|
||||
if got != tt.want {
|
||||
t.Errorf("Delay(%d) = %v, want %v", tt.attempt, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_SuccessOnFirstTry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
calls := 0
|
||||
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
||||
calls++
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("expected fn called once, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_TransientThenSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
calls := 0
|
||||
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
return xerrors.New("service unavailable")
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected fn called twice, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_MultipleTransientThenSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
calls := 0
|
||||
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
||||
calls++
|
||||
if calls <= 3 {
|
||||
return xerrors.New("overloaded")
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if calls != 4 {
|
||||
t.Fatalf("expected fn called 4 times, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_NonRetryableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
calls := 0
|
||||
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
||||
calls++
|
||||
return xerrors.New("invalid api key")
|
||||
}, nil)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if err.Error() != "invalid api key" {
|
||||
t.Fatalf("expected 'invalid api key', got %q", err.Error())
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("expected fn called once, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_ContextCanceledDuringWait(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
calls := 0
|
||||
err := chatretry.Retry(ctx, func(_ context.Context) error {
|
||||
calls++
|
||||
// Cancel after the first retryable error so the wait
|
||||
// select picks up the cancellation.
|
||||
if calls == 1 {
|
||||
cancel()
|
||||
}
|
||||
return xerrors.New("overloaded")
|
||||
}, nil)
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_ContextCanceledDuringFn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
err := chatretry.Retry(ctx, func(_ context.Context) error {
|
||||
cancel()
|
||||
// Return a retryable error; the loop should detect that
|
||||
// ctx is done and return the context error.
|
||||
return xerrors.New("overloaded")
|
||||
}, nil)
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_OnRetryCalledWithCorrectArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type retryRecord struct {
|
||||
attempt int
|
||||
errMsg string
|
||||
delay time.Duration
|
||||
}
|
||||
var records []retryRecord
|
||||
|
||||
calls := 0
|
||||
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
||||
calls++
|
||||
if calls <= 2 {
|
||||
return xerrors.New("rate limit exceeded")
|
||||
}
|
||||
return nil
|
||||
}, func(attempt int, err error, delay time.Duration) {
|
||||
records = append(records, retryRecord{
|
||||
attempt: attempt,
|
||||
errMsg: err.Error(),
|
||||
delay: delay,
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("expected 2 onRetry calls, got %d", len(records))
|
||||
}
|
||||
if records[0].attempt != 1 {
|
||||
t.Errorf("first onRetry attempt = %d, want 1", records[0].attempt)
|
||||
}
|
||||
if records[1].attempt != 2 {
|
||||
t.Errorf("second onRetry attempt = %d, want 2", records[1].attempt)
|
||||
}
|
||||
if records[0].errMsg != "rate limit exceeded" {
|
||||
t.Errorf("first onRetry error = %q, want 'rate limit exceeded'", records[0].errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry_OnRetryNilDoesNotPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var calls atomic.Int32
|
||||
err := chatretry.Retry(context.Background(), func(_ context.Context) error {
|
||||
if calls.Add(1) == 1 {
|
||||
return xerrors.New("overloaded")
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ type AnthropicHandler func(req *AnthropicRequest) AnthropicResponse
|
||||
type AnthropicResponse struct {
|
||||
StreamingChunks <-chan AnthropicChunk
|
||||
Response *AnthropicMessage
|
||||
Error *ErrorResponse // If set, server returns this HTTP error instead of streaming/JSON.
|
||||
}
|
||||
|
||||
// AnthropicRequest represents an Anthropic messages request.
|
||||
@@ -141,6 +142,11 @@ func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrorResponse describes an HTTP error that a test server should return
// instead of a normal streaming or JSON response.
type ErrorResponse struct {
	// StatusCode is the HTTP status sent to the client. Values outside
	// 100-999 are replaced with 500 by writeErrorResponse.
	StatusCode int
	// Type is emitted as the "type" field of the JSON error object.
	Type string
	// Message is emitted as the "message" field of the JSON error object.
	Message string
}

// writeErrorResponse writes errResp to w as a JSON document of the form
// {"error": {"type": ..., "message": ...}} with an application/json
// content type. This is the error envelope the Anthropic and OpenAI
// test servers share. errResp must be non-nil.
func writeErrorResponse(w http.ResponseWriter, errResp *ErrorResponse) {
	status := errResp.StatusCode
	if status < 100 || status > 999 {
		// net/http panics on WriteHeader codes outside 100-999, so an
		// ErrorResponse with an unset StatusCode would crash the
		// handler instead of producing an error response. Fall back to
		// a generic server error.
		status = http.StatusInternalServerError
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	body := map[string]any{
		"error": map[string]any{
			"type":    errResp.Type,
			"message": errResp.Message,
		},
	}
	// Best-effort: the status line has already been sent, so a failed
	// write (for example the client hanging up) cannot be reported to
	// the caller in any useful way.
	_ = json.NewEncoder(w).Encode(body)
}
|
||||
|
||||
// AnthropicErrorResponse returns an AnthropicResponse that causes the
|
||||
// test server to respond with the given HTTP status code and error.
|
||||
// This simulates provider errors like 529 Overloaded or 429 Rate Limited.
|
||||
func AnthropicErrorResponse(statusCode int, errorType, message string) AnthropicResponse {
|
||||
return AnthropicResponse{
|
||||
Error: &ErrorResponse{
|
||||
StatusCode: statusCode,
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicOverloadedResponse returns a 529 "overloaded" error matching
|
||||
// Anthropic's overloaded response format.
|
||||
func AnthropicOverloadedResponse() AnthropicResponse {
|
||||
return AnthropicErrorResponse(529, "overloaded_error", "Overloaded")
|
||||
}
|
||||
|
||||
// AnthropicRateLimitResponse returns a 429 rate limit error.
|
||||
func AnthropicRateLimitResponse() AnthropicResponse {
|
||||
return AnthropicErrorResponse(http.StatusTooManyRequests, "rate_limit_error", "Rate limited")
|
||||
}
|
||||
|
||||
// OpenAIErrorResponse returns an OpenAIResponse that causes the
|
||||
// test server to respond with the given HTTP status code and error.
|
||||
func OpenAIErrorResponse(statusCode int, errorType, message string) OpenAIResponse {
|
||||
return OpenAIResponse{
|
||||
Error: &ErrorResponse{
|
||||
StatusCode: statusCode,
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIRateLimitResponse returns a 429 rate limit error.
|
||||
func OpenAIRateLimitResponse() OpenAIResponse {
|
||||
return OpenAIErrorResponse(http.StatusTooManyRequests, "rate_limit_exceeded", "Rate limit exceeded")
|
||||
}
|
||||
|
||||
// OpenAIServerErrorResponse returns a 500 internal server error.
|
||||
func OpenAIServerErrorResponse() OpenAIResponse {
|
||||
return OpenAIErrorResponse(http.StatusInternalServerError, "server_error", "Internal server error")
|
||||
}
|
||||
@@ -20,6 +20,7 @@ type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
|
||||
type OpenAIResponse struct {
|
||||
StreamingChunks <-chan OpenAIChunk
|
||||
Response *OpenAICompletion
|
||||
Error *ErrorResponse // If set, server returns this HTTP error instead of streaming/JSON.
|
||||
}
|
||||
|
||||
// OpenAIRequest represents an OpenAI chat completion request.
|
||||
@@ -160,6 +161,11 @@ func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
@@ -184,6 +190,11 @@ func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
hasStreaming := resp.StreamingChunks != nil
|
||||
hasNonStreaming := resp.Response != nil
|
||||
|
||||
|
||||
+13
-5
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
@@ -65,7 +66,8 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
|
||||
// generateTitle calls the model with a title-generation system prompt
|
||||
// and returns the normalized result.
|
||||
// and returns the normalized result. It retries transient LLM errors
|
||||
// (rate limits, overloaded, etc.) with exponential backoff.
|
||||
func generateTitle(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -86,10 +88,16 @@ func generateTitle(
|
||||
},
|
||||
}
|
||||
toolChoice := fantasy.ToolChoiceNone
|
||||
response, err := model.Generate(ctx, fantasy.Call{
|
||||
Prompt: prompt,
|
||||
ToolChoice: &toolChoice,
|
||||
})
|
||||
|
||||
var response *fantasy.Response
|
||||
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
|
||||
var genErr error
|
||||
response, genErr = model.Generate(retryCtx, fantasy.Call{
|
||||
Prompt: prompt,
|
||||
ToolChoice: &toolChoice,
|
||||
})
|
||||
return genErr
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate title text: %w", err)
|
||||
}
|
||||
|
||||
@@ -444,6 +444,7 @@ const (
|
||||
ChatStreamEventTypeStatus ChatStreamEventType = "status"
|
||||
ChatStreamEventTypeError ChatStreamEventType = "error"
|
||||
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
|
||||
ChatStreamEventTypeRetry ChatStreamEventType = "retry"
|
||||
)
|
||||
|
||||
// ChatQueuedMessage represents a queued message waiting to be processed.
|
||||
@@ -470,6 +471,19 @@ type ChatStreamError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ChatStreamRetry represents an auto-retry status event in the stream.
|
||||
// Published when the server automatically retries a failed LLM call.
|
||||
type ChatStreamRetry struct {
|
||||
// Attempt is the 1-indexed retry attempt number.
|
||||
Attempt int `json:"attempt"`
|
||||
// DelayMs is the backoff delay in milliseconds before the retry.
|
||||
DelayMs int64 `json:"delay_ms"`
|
||||
// Error is the error message from the failed attempt.
|
||||
Error string `json:"error"`
|
||||
// RetryingAt is the timestamp when the retry will be attempted.
|
||||
RetryingAt time.Time `json:"retrying_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatStreamEvent represents a real-time update for chat streaming.
|
||||
type ChatStreamEvent struct {
|
||||
Type ChatStreamEventType `json:"type"`
|
||||
@@ -478,6 +492,7 @@ type ChatStreamEvent struct {
|
||||
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
|
||||
Status *ChatStreamStatus `json:"status,omitempty"`
|
||||
Error *ChatStreamError `json:"error,omitempty"`
|
||||
Retry *ChatStreamRetry `json:"retry,omitempty"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
Generated
+27
@@ -1516,6 +1516,7 @@ export interface ChatStreamEvent {
|
||||
readonly message_part?: ChatStreamMessagePart;
|
||||
readonly status?: ChatStreamStatus;
|
||||
readonly error?: ChatStreamError;
|
||||
readonly retry?: ChatStreamRetry;
|
||||
readonly queued_messages?: readonly ChatQueuedMessage[];
|
||||
}
|
||||
|
||||
@@ -1525,6 +1526,7 @@ export type ChatStreamEventType =
|
||||
| "message"
|
||||
| "message_part"
|
||||
| "queue_update"
|
||||
| "retry"
|
||||
| "status";
|
||||
|
||||
export const ChatStreamEventTypes: ChatStreamEventType[] = [
|
||||
@@ -1532,6 +1534,7 @@ export const ChatStreamEventTypes: ChatStreamEventType[] = [
|
||||
"message",
|
||||
"message_part",
|
||||
"queue_update",
|
||||
"retry",
|
||||
"status",
|
||||
];
|
||||
|
||||
@@ -1544,6 +1547,30 @@ export interface ChatStreamMessagePart {
|
||||
readonly part: ChatMessagePart;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatStreamRetry represents an auto-retry status event in the stream.
|
||||
* Published when the server automatically retries a failed LLM call.
|
||||
*/
|
||||
export interface ChatStreamRetry {
|
||||
/**
|
||||
* Attempt is the 1-indexed retry attempt number.
|
||||
*/
|
||||
readonly attempt: number;
|
||||
/**
|
||||
* DelayMs is the backoff delay in milliseconds before the retry.
|
||||
*/
|
||||
readonly delay_ms: number;
|
||||
/**
|
||||
* Error is the error message from the failed attempt.
|
||||
*/
|
||||
readonly error: string;
|
||||
/**
|
||||
* RetryingAt is the timestamp when the retry will be attempted.
|
||||
*/
|
||||
readonly retrying_at: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatStreamStatus represents an updated chat status.
|
||||
|
||||
@@ -35,6 +35,7 @@ import {
|
||||
selectMessagesByID,
|
||||
selectOrderedMessageIDs,
|
||||
selectQueuedMessages,
|
||||
selectRetryState,
|
||||
selectStreamError,
|
||||
selectStreamState,
|
||||
selectSubagentStatusOverrides,
|
||||
@@ -121,6 +122,7 @@ const AgentDetailTimeline: FC<AgentDetailTimelineProps> = ({
|
||||
store,
|
||||
selectSubagentStatusOverrides,
|
||||
);
|
||||
const retryState = useChatSelector(store, selectRetryState);
|
||||
|
||||
const messages = useMemo(
|
||||
() =>
|
||||
@@ -172,6 +174,7 @@ const AgentDetailTimeline: FC<AgentDetailTimelineProps> = ({
|
||||
streamTools={streamTools}
|
||||
subagentTitles={subagentTitles}
|
||||
subagentStatusOverrides={subagentStatusOverrides}
|
||||
retryState={retryState}
|
||||
isAwaitingFirstStreamChunk={isAwaitingFirstStreamChunk}
|
||||
detailErrorMessage={detailErrorMessage}
|
||||
onEditUserMessage={onEditUserMessage}
|
||||
|
||||
@@ -141,6 +141,7 @@ type ChatStoreState = {
|
||||
streamState: StreamState | null;
|
||||
chatStatus: TypesGen.ChatStatus | null;
|
||||
streamError: string | null;
|
||||
retryState: { attempt: number; error: string } | null;
|
||||
queuedMessages: readonly TypesGen.ChatQueuedMessage[];
|
||||
subagentStatusOverrides: Map<string, TypesGen.ChatStatus>;
|
||||
};
|
||||
@@ -163,6 +164,8 @@ type ChatStore = {
|
||||
setChatStatus: (status: TypesGen.ChatStatus | null) => void;
|
||||
setStreamError: (reason: string | null) => void;
|
||||
clearStreamError: () => void;
|
||||
setRetryState: (state: { attempt: number; error: string } | null) => void;
|
||||
clearRetryState: () => void;
|
||||
clearStreamState: () => void;
|
||||
setSubagentStatusOverride: (
|
||||
chatID: string,
|
||||
@@ -177,6 +180,7 @@ const createInitialState = (): ChatStoreState => ({
|
||||
streamState: null,
|
||||
chatStatus: null,
|
||||
streamError: null,
|
||||
retryState: null,
|
||||
queuedMessages: [],
|
||||
subagentStatusOverrides: new Map(),
|
||||
});
|
||||
@@ -313,6 +317,24 @@ const createChatStore = (): ChatStore => {
|
||||
streamError: null,
|
||||
}));
|
||||
},
|
||||
// Stores the latest auto-retry status reported by the stream, or null to
// clear it. The update is skipped when the incoming value is the same
// reference as the stored one, or carries the same attempt and error.
// The "retry" stream handler builds a fresh object for every event, so a
// reference comparison alone would only ever match for null and would
// replace the state with identical data.
setRetryState: (retryState) => {
	const previous = state.retryState;
	if (previous === retryState) {
		return;
	}
	if (
		previous !== null &&
		retryState !== null &&
		previous.attempt === retryState.attempt &&
		previous.error === retryState.error
	) {
		return;
	}
	setState((current) => ({
		...current,
		retryState,
	}));
},
|
||||
// Drops any stored auto-retry status. Does nothing when no retry status is
// currently held, so the state object is left untouched in that case.
clearRetryState: () => {
	if (state.retryState !== null) {
		setState((previous) => ({ ...previous, retryState: null }));
	}
},
|
||||
clearStreamState: () => {
|
||||
if (state.streamState === null) {
|
||||
return;
|
||||
@@ -337,6 +359,7 @@ const createChatStore = (): ChatStore => {
|
||||
if (
|
||||
state.streamState === null &&
|
||||
state.streamError === null &&
|
||||
state.retryState === null &&
|
||||
state.subagentStatusOverrides.size === 0
|
||||
) {
|
||||
return;
|
||||
@@ -345,6 +368,7 @@ const createChatStore = (): ChatStore => {
|
||||
...current,
|
||||
streamState: null,
|
||||
streamError: null,
|
||||
retryState: null,
|
||||
subagentStatusOverrides: new Map(),
|
||||
}));
|
||||
},
|
||||
@@ -373,6 +397,7 @@ export const selectQueuedMessages = (state: ChatStoreState) =>
|
||||
state.queuedMessages;
|
||||
export const selectSubagentStatusOverrides = (state: ChatStoreState) =>
|
||||
state.subagentStatusOverrides;
|
||||
/** Selects the current auto-retry status from the chat store state (null when none). */
export const selectRetryState = (state: ChatStoreState) => {
	return state.retryState;
};
|
||||
|
||||
export const useChatStore = (
|
||||
options: UseChatStoreOptions,
|
||||
@@ -612,6 +637,10 @@ export const useChatStore = (
|
||||
store.setChatStatus(nextStatus);
|
||||
if (nextStatus === "pending" || nextStatus === "waiting") {
|
||||
store.clearStreamState();
|
||||
store.clearRetryState();
|
||||
}
|
||||
if (nextStatus === "running") {
|
||||
store.clearRetryState();
|
||||
}
|
||||
if (nextStatus !== "error") {
|
||||
clearChatErrorReason(chatID);
|
||||
@@ -630,6 +659,7 @@ export const useChatStore = (
|
||||
asString(error?.message).trim() || "Chat processing failed.";
|
||||
store.setChatStatus("error");
|
||||
store.setStreamError(reason);
|
||||
store.clearRetryState();
|
||||
setChatErrorReason(chatID, reason);
|
||||
updateSidebarChat((chat) => ({
|
||||
...chat,
|
||||
@@ -638,6 +668,16 @@ export const useChatStore = (
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
case "retry": {
|
||||
const retry = streamEvent.retry;
|
||||
if (retry) {
|
||||
store.setRetryState({
|
||||
attempt: retry.attempt,
|
||||
error: retry.error,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
default:
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -307,12 +307,13 @@ const ChatMessageItem = memo<{
|
||||
);
|
||||
ChatMessageItem.displayName = "ChatMessageItem";
|
||||
|
||||
const StreamingOutput = memo<{
|
||||
export const StreamingOutput = memo<{
|
||||
streamState: StreamState | null;
|
||||
streamTools: readonly MergedTool[];
|
||||
subagentTitles?: Map<string, string>;
|
||||
subagentStatusOverrides?: Map<string, TypesGen.ChatStatus>;
|
||||
showInitialPlaceholder?: boolean;
|
||||
retryState?: { attempt: number; error: string } | null;
|
||||
}>(
|
||||
({
|
||||
streamState,
|
||||
@@ -320,6 +321,7 @@ const StreamingOutput = memo<{
|
||||
subagentTitles,
|
||||
subagentStatusOverrides,
|
||||
showInitialPlaceholder = false,
|
||||
retryState,
|
||||
}) => {
|
||||
const conversationItemProps = { role: "assistant" as const };
|
||||
const toolByID = new Map(streamTools.map((tool) => [tool.id, tool]));
|
||||
@@ -348,12 +350,17 @@ const StreamingOutput = memo<{
|
||||
streamTools.length === 0) ? (
|
||||
<div className="relative">
|
||||
<Response aria-hidden className="invisible">
|
||||
Thinking...
|
||||
{`Thinking...${retryState ? ` attempt ${retryState.attempt}` : ""}`}
|
||||
</Response>
|
||||
<div className="pointer-events-none absolute inset-0">
|
||||
<div className="pointer-events-none absolute inset-0 flex items-baseline gap-2">
|
||||
<Shimmer as="div" className="text-[13px] leading-relaxed">
|
||||
Thinking...
|
||||
</Shimmer>
|
||||
{retryState && (
|
||||
<span className="text-[11px] text-content-secondary">
|
||||
attempt {retryState.attempt}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
@@ -598,6 +605,7 @@ type ConversationTimelineProps = {
|
||||
streamTools: readonly MergedTool[];
|
||||
subagentTitles: Map<string, string>;
|
||||
subagentStatusOverrides: Map<string, TypesGen.ChatStatus>;
|
||||
retryState?: { attempt: number; error: string } | null;
|
||||
isAwaitingFirstStreamChunk: boolean;
|
||||
detailErrorMessage?: string | null;
|
||||
onEditUserMessage?: (messageId: number, text: string) => void;
|
||||
@@ -615,6 +623,7 @@ export const ConversationTimeline: FC<ConversationTimelineProps> = ({
|
||||
streamTools,
|
||||
subagentTitles,
|
||||
subagentStatusOverrides,
|
||||
retryState,
|
||||
isAwaitingFirstStreamChunk,
|
||||
detailErrorMessage,
|
||||
onEditUserMessage,
|
||||
@@ -677,6 +686,7 @@ export const ConversationTimeline: FC<ConversationTimelineProps> = ({
|
||||
subagentTitles={subagentTitles}
|
||||
subagentStatusOverrides={subagentStatusOverrides}
|
||||
showInitialPlaceholder={isAwaitingFirstStreamChunk}
|
||||
retryState={retryState}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -689,6 +699,7 @@ export const ConversationTimeline: FC<ConversationTimelineProps> = ({
|
||||
subagentTitles={subagentTitles}
|
||||
subagentStatusOverrides={subagentStatusOverrides}
|
||||
showInitialPlaceholder={isAwaitingFirstStreamChunk}
|
||||
retryState={retryState}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { StreamingOutput } from "./ConversationTimeline";
|
||||
|
||||
// StreamingOutput renders inside a ConversationItem > Message > MessageContent
|
||||
// chain, but it's self-contained enough to render standalone.
|
||||
|
||||
const meta: Meta<typeof StreamingOutput> = {
|
||||
title: "pages/AgentsPage/AgentDetail/StreamingOutput",
|
||||
component: StreamingOutput,
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<div className="mx-auto w-full max-w-3xl py-6">
|
||||
<Story />
|
||||
</div>
|
||||
),
|
||||
],
|
||||
};
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof StreamingOutput>;
|
||||
|
||||
/** Default shimmer placeholder with no stream state. */
|
||||
export const ThinkingPlaceholder: Story = {
|
||||
args: {
|
||||
streamState: null,
|
||||
streamTools: [],
|
||||
showInitialPlaceholder: true,
|
||||
},
|
||||
};
|
||||
|
||||
/** First retry attempt. */
|
||||
export const RetryAttempt1: Story = {
|
||||
args: {
|
||||
streamState: null,
|
||||
streamTools: [],
|
||||
showInitialPlaceholder: true,
|
||||
retryState: { attempt: 1, error: "service unavailable" },
|
||||
},
|
||||
};
|
||||
|
||||
/** Third retry attempt. */
|
||||
export const RetryAttempt3: Story = {
|
||||
args: {
|
||||
streamState: null,
|
||||
streamTools: [],
|
||||
showInitialPlaceholder: true,
|
||||
retryState: { attempt: 3, error: "rate limit exceeded" },
|
||||
},
|
||||
};
|
||||
|
||||
/** Higher attempt number to see how it looks. */
|
||||
export const RetryHighAttempt: Story = {
|
||||
args: {
|
||||
streamState: null,
|
||||
streamTools: [],
|
||||
showInitialPlaceholder: true,
|
||||
retryState: { attempt: 12, error: "overloaded" },
|
||||
},
|
||||
};
|
||||
|
||||
/** Active streaming with partial text content. */
|
||||
export const StreamingWithText: Story = {
|
||||
args: {
|
||||
streamState: {
|
||||
blocks: [
|
||||
{
|
||||
type: "response" as const,
|
||||
text: "Here is a partial response that is still being generated...",
|
||||
},
|
||||
],
|
||||
toolCalls: {},
|
||||
toolResults: {},
|
||||
},
|
||||
streamTools: [],
|
||||
},
|
||||
};
|
||||
|
||||
/** Content arrived after retries (no retry indicator shown). */
|
||||
export const StreamingAfterRetry: Story = {
|
||||
args: {
|
||||
streamState: {
|
||||
blocks: [
|
||||
{
|
||||
type: "response" as const,
|
||||
text: "Successfully connected after retry. Here is your answer...",
|
||||
},
|
||||
],
|
||||
toolCalls: {},
|
||||
toolResults: {},
|
||||
},
|
||||
streamTools: [],
|
||||
retryState: null,
|
||||
},
|
||||
};
|
||||
Reference in New Issue
Block a user